Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
473a045a
Commit
473a045a
authored
Mar 08, 2019
by
Paul
Browse files
Fix memory conflicts
parent
84221940
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
97 additions
and
45 deletions
+97
-45
src/include/migraphx/iterator_for.hpp
src/include/migraphx/iterator_for.hpp
+54
-5
src/schedule.cpp
src/schedule.cpp
+38
-38
test/schedule_test.cpp
test/schedule_test.cpp
+5
-2
No files found.
src/include/migraphx/iterator_for.hpp
View file @
473a045a
...
@@ -3,21 +3,64 @@
...
@@ -3,21 +3,64 @@
#include <cassert>
#include <cassert>
#include <type_traits>
#include <type_traits>
#include <iterator>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
T
>
struct
iterator_for_select
{
template
<
class
T
>
static
T
deref
(
T
x
)
{
return
x
;
}
template
<
class
T
>
static
auto
begin
(
T
*
x
)
{
return
x
->
begin
();
}
template
<
class
T
>
static
auto
end
(
T
*
x
)
{
return
x
->
end
();
}
};
struct
iterator_for_select_reverse
{
template
<
class
T
>
static
auto
deref
(
T
x
)
{
return
std
::
prev
(
x
.
base
());
}
template
<
class
T
>
static
auto
begin
(
T
*
x
)
{
return
std
::
make_reverse_iterator
(
x
->
end
());
}
template
<
class
T
>
static
auto
end
(
T
*
x
)
{
return
std
::
make_reverse_iterator
(
x
->
begin
());
}
};
template
<
class
T
,
class
Selector
=
iterator_for_select
>
struct
iterator_for_range
struct
iterator_for_range
{
{
T
*
base
;
T
*
base
;
using
base_iterator
=
std
::
remove_reference_t
<
decltype
(
base
->
begin
())
>
;
using
base_iterator
=
std
::
remove_reference_t
<
decltype
(
Selector
::
begin
(
base
))
>
;
struct
iterator
struct
iterator
{
{
base_iterator
i
;
base_iterator
i
;
base_iter
ato
r
operator
*
()
const
{
return
i
;
}
a
u
to
operator
*
()
const
{
return
Selector
::
deref
(
i
)
;
}
base_iterator
operator
++
()
{
return
++
i
;
}
base_iterator
operator
++
()
{
return
++
i
;
}
bool
operator
!=
(
const
iterator
&
rhs
)
const
{
return
i
!=
rhs
.
i
;
}
bool
operator
!=
(
const
iterator
&
rhs
)
const
{
return
i
!=
rhs
.
i
;
}
};
};
...
@@ -25,12 +68,12 @@ struct iterator_for_range
...
@@ -25,12 +68,12 @@ struct iterator_for_range
iterator
begin
()
iterator
begin
()
{
{
assert
(
base
!=
nullptr
);
assert
(
base
!=
nullptr
);
return
{
base
->
begin
()};
return
{
Selector
::
begin
(
base
)};
}
}
iterator
end
()
iterator
end
()
{
{
assert
(
base
!=
nullptr
);
assert
(
base
!=
nullptr
);
return
{
base
->
end
(
)};
return
{
Selector
::
end
(
base
)};
}
}
};
};
template
<
class
T
>
template
<
class
T
>
...
@@ -39,6 +82,12 @@ iterator_for_range<T> iterator_for(T& x)
...
@@ -39,6 +82,12 @@ iterator_for_range<T> iterator_for(T& x)
return
{
&
x
};
return
{
&
x
};
}
}
template
<
class
T
>
iterator_for_range
<
T
,
iterator_for_select_reverse
>
reverse_iterator_for
(
T
&
x
)
{
return
{
&
x
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/schedule.cpp
View file @
473a045a
...
@@ -162,7 +162,7 @@ struct stream_info
...
@@ -162,7 +162,7 @@ struct stream_info
}
}
template
<
class
Selector
>
template
<
class
Selector
>
auto
get_streams
(
instruction_ref
start
,
Selector
select
)
const
auto
get_streams
_from
(
instruction_ref
start
,
Selector
select
)
const
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
return
fix
<
bool
>
([
&
](
auto
self
,
auto
ins
)
{
return
fix
<
bool
>
([
&
](
auto
self
,
auto
ins
)
{
...
@@ -184,16 +184,28 @@ struct stream_info
...
@@ -184,16 +184,28 @@ struct stream_info
};
};
}
}
std
::
unordered_set
<
std
::
size_t
>
get_streams
(
instruction_ref
ins
)
{
if
(
has_stream
(
ins
))
return
{
get_stream
(
ins
)};
std
::
unordered_set
<
std
::
size_t
>
result
;
get_streams_from
(
ins
,
get_inputs
())([
&
](
auto
s
)
{
result
.
insert
(
s
);
return
true
;
});
return
result
;
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
bool
is_merge_point
(
instruction_ref
ins
,
Ts
...
xs
)
const
bool
is_merge_point
(
instruction_ref
ins
,
Ts
...
xs
)
const
{
{
return
different
(
get_streams
(
ins
,
get_inputs
()),
xs
...);
return
different
(
get_streams
_from
(
ins
,
get_inputs
()),
xs
...);
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
bool
is_split_point
(
instruction_ref
ins
,
Ts
...
xs
)
const
bool
is_split_point
(
instruction_ref
ins
,
Ts
...
xs
)
const
{
{
return
different
(
get_streams
(
ins
,
get_outputs
()),
xs
...);
return
different
(
get_streams
_from
(
ins
,
get_outputs
()),
xs
...);
}
}
std
::
vector
<
instruction_ref
>
get_recorded_instructions
(
instruction_ref
start
)
std
::
vector
<
instruction_ref
>
get_recorded_instructions
(
instruction_ref
start
)
...
@@ -225,7 +237,7 @@ struct stream_info
...
@@ -225,7 +237,7 @@ struct stream_info
std
::
vector
<
std
::
size_t
>
wait_for
(
instruction_ref
ins
)
const
std
::
vector
<
std
::
size_t
>
wait_for
(
instruction_ref
ins
)
const
{
{
std
::
vector
<
std
::
size_t
>
result
;
std
::
vector
<
std
::
size_t
>
result
;
get_streams
(
ins
,
get_inputs
())([
&
](
auto
s
)
{
get_streams
_from
(
ins
,
get_inputs
())([
&
](
auto
s
)
{
result
.
push_back
(
s
);
result
.
push_back
(
s
);
return
true
;
return
true
;
});
});
...
@@ -243,35 +255,27 @@ struct stream_info
...
@@ -243,35 +255,27 @@ struct stream_info
find_concurrent_instructions
(
program
&
p
)
find_concurrent_instructions
(
program
&
p
)
{
{
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
result
;
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
result
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
split
_from
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
merge
_from
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
reverse_
iterator_for
(
p
))
{
{
if
(
iweights
[
ins
]
==
0
)
for
(
auto
&&
arg
:
ins
->
outputs
())
continue
;
for
(
auto
&&
arg
:
ins
->
inputs
())
{
{
if
(
is_
split
_point
(
arg
))
if
(
is_
merge
_point
(
arg
))
split
_from
[
ins
].
insert
(
arg
);
merge
_from
[
ins
].
insert
(
arg
);
split
_from
[
ins
].
insert
(
split
_from
[
arg
].
begin
(),
split
_from
[
arg
].
end
());
merge
_from
[
ins
].
insert
(
merge
_from
[
arg
].
begin
(),
merge
_from
[
arg
].
end
());
}
}
auto
stream
=
get_stream
(
ins
);
auto
streams
=
get_streams
(
ins
);
// if (is_merge_point(ins))
// {
// Collect concur instructions for each merge point.
// // post-dominator kills split point.
for
(
auto
&
merge
:
merge_from
[
ins
])
// for(auto& split : split_from[ins])
// {
// if(strictly_post_dominates(ins, split))
// split_from[ins].erase(split);
// }
// }
// Collect concur instructions for each split point.
for
(
auto
&
split
:
split_from
[
ins
])
{
{
if
(
result
[
split
].
size
()
<=
stream
)
for
(
auto
stream
:
streams
)
result
[
split
].
resize
(
stream
+
1
);
{
result
[
split
][
stream
].
push_back
(
ins
);
if
(
result
[
merge
].
size
()
<=
stream
)
result
[
merge
].
resize
(
stream
+
1
);
result
[
merge
][
stream
].
push_back
(
ins
);
}
}
}
}
}
return
result
;
return
result
;
...
@@ -304,7 +308,7 @@ void schedule::apply(program& p) const
...
@@ -304,7 +308,7 @@ void schedule::apply(program& p) const
std
::
cout
<<
":"
;
std
::
cout
<<
":"
;
std
::
cout
<<
" weight="
<<
si
.
weights
.
at
(
ins
);
std
::
cout
<<
" weight="
<<
si
.
weights
.
at
(
ins
);
std
::
cout
<<
" input={"
;
std
::
cout
<<
" input={"
;
si
.
get_streams
(
ins
,
get_inputs
())([
&
](
auto
s
)
{
si
.
get_streams
_from
(
ins
,
get_inputs
())([
&
](
auto
s
)
{
std
::
cout
<<
s
<<
","
;
std
::
cout
<<
s
<<
","
;
return
true
;
return
true
;
});
});
...
@@ -367,20 +371,16 @@ void schedule::apply(program& p) const
...
@@ -367,20 +371,16 @@ void schedule::apply(program& p) const
// Add memory conflicts
// Add memory conflicts
auto
concur_ins
=
si
.
find_concurrent_instructions
(
p
);
auto
concur_ins
=
si
.
find_concurrent_instructions
(
p
);
for
(
auto
&&
split
:
concur_ins
)
for
(
auto
&&
merge
:
concur_ins
)
{
{
dfor
(
split
.
second
.
size
(),
split
.
second
.
size
())([
&
](
auto
i
,
auto
j
)
{
dfor
(
merge
.
second
.
size
(),
merge
.
second
.
size
())([
&
](
auto
i
,
auto
j
)
{
if
(
i
==
j
)
if
(
i
==
j
)
return
;
return
;
for
(
auto
ins1
:
split
.
second
[
i
])
for
(
auto
ins1
:
merge
.
second
[
i
])
{
{
auto
args
=
split
.
second
[
j
];
auto
args
=
merge
.
second
[
j
];
args
.
insert
(
args
.
begin
(),
ins1
);
args
.
insert
(
args
.
begin
(),
ins1
);
p
.
insert_instruction
(
merge
.
first
,
op
::
identity
{},
args
);
auto
point
=
std
::
max_element
(
args
.
begin
(),
args
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
std
::
distance
(
split
.
first
,
x
)
<
std
::
distance
(
split
.
first
,
y
);
});
p
.
insert_instruction
(
std
::
next
(
*
point
),
op
::
identity
{},
args
);
}
}
});
});
}
}
...
...
test/schedule_test.cpp
View file @
473a045a
...
@@ -94,7 +94,7 @@ struct schedule_model_test
...
@@ -94,7 +94,7 @@ struct schedule_model_test
}
}
(
*
ins2wait_for
)[
ins
]
->
push_back
(
wait2stream
->
at
(
wait_id
));
(
*
ins2wait_for
)[
ins
]
->
push_back
(
wait2stream
->
at
(
wait_id
));
}
}
void
record
(
migraphx
::
program
&
p
,
migraphx
::
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
void
record
(
migraphx
::
program
&
,
migraphx
::
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
{
{
(
*
wait2stream
)[
wait_id
]
=
ins2stream
->
at
(
ins
);
(
*
wait2stream
)[
wait_id
]
=
ins2stream
->
at
(
ins
);
}
}
...
@@ -181,6 +181,9 @@ std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size
...
@@ -181,6 +181,9 @@ std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size
std
::
vector
<
std
::
size_t
>
get_wait_for
(
migraphx
::
instruction_ref
ins
)
std
::
vector
<
std
::
size_t
>
get_wait_for
(
migraphx
::
instruction_ref
ins
)
{
{
auto
wait_ins
=
std
::
prev
(
ins
);
auto
wait_ins
=
std
::
prev
(
ins
);
// Skip identity operators
while
(
wait_ins
->
name
()
==
"identity"
)
wait_ins
=
std
::
prev
(
wait_ins
);
if
(
wait_ins
->
name
()
!=
"wait_event"
)
if
(
wait_ins
->
name
()
!=
"wait_event"
)
return
{};
return
{};
auto
wf
=
*
migraphx
::
any_cast
<
wait_event
>
(
wait_ins
->
get_operator
()).
wait_for
;
auto
wf
=
*
migraphx
::
any_cast
<
wait_event
>
(
wait_ins
->
get_operator
()).
wait_for
;
...
@@ -338,7 +341,7 @@ TEST_CASE(double_entry)
...
@@ -338,7 +341,7 @@ TEST_CASE(double_entry)
EXPECT
(
t
.
get_stream
(
binary
)
==
0
);
EXPECT
(
t
.
get_stream
(
binary
)
==
0
);
EXPECT
(
get_wait_for
(
binary
)
==
EXPECT
(
get_wait_for
(
binary
)
==
get_wait_for
(
t
.
get_stream
(
binary
),
{
t
.
get_stream
(
onep
),
t
.
get_stream
(
twop
)}));
get_wait_for
(
t
.
get_stream
(
binary
),
{
t
.
get_stream
(
onep
),
t
.
get_stream
(
twop
)}));
//
EXPECT(check_conflicts(p, onep, twop));
EXPECT
(
check_conflicts
(
p
,
onep
,
twop
));
}
}
TEST_CASE
(
two_branches
)
TEST_CASE
(
two_branches
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment