Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4008675f
Commit
4008675f
authored
Mar 11, 2019
by
Paul
Browse files
Sort while creating partitions
parent
6e9142b5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
69 additions
and
36 deletions
+69
-36
src/program.cpp
src/program.cpp
+0
-2
src/schedule.cpp
src/schedule.cpp
+56
-24
test/schedule_test.cpp
test/schedule_test.cpp
+13
-10
No files found.
src/program.cpp
View file @
4008675f
...
...
@@ -104,11 +104,9 @@ instruction_ref program::insert_instruction(instruction_ref ins,
args
.
begin
(),
args
.
end
(),
[
&
](
instruction_ref
x
)
{
return
has_instruction
(
x
);
})
&&
"Argument is not an exisiting instruction"
);
assert
(
not
starts_with
(
op
.
name
(),
"@"
));
// TODO: Use move
shape
r
=
compute_shape
(
op
,
args
);
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
std
::
move
(
args
)});
instruction
::
backreference
(
result
);
// assert(result->inputs() == args);
assert
(
result
->
valid
(
begin
()));
return
result
;
}
...
...
src/schedule.cpp
View file @
4008675f
...
...
@@ -50,6 +50,39 @@ struct stream_info
})(
last
);
}
// Reorders `args` in place, ascending by the pair
// (this->weights[arg], arg->inputs().size()), and returns an iterator to the
// first argument whose weight is strictly greater than the partition threshold.
//
// Returns args.end() when no argument exceeds the threshold. Lists with fewer
// than two arguments are left untouched and always yield args.end(), without
// consulting the weights at all.
std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args)
{
    const std::size_t min_partition_threshold = 2;

    // Sort key shared by the two-element fast path and the general sort.
    auto sort_key = [&](auto x) {
        return std::make_tuple(this->weights[x], x->inputs().size());
    };

    if(args.size() < 2)
        return args.end();

    if(args.size() == 2)
    {
        // Two elements: a single conditional swap puts them in order.
        if(sort_key(args[0]) > sort_key(args[1]))
            std::swap(args[0], args[1]);

        // The first element over the threshold (if any) is the split point.
        return std::find_if(args.begin(), args.end(), [&](auto arg) {
            return this->weights[arg] > min_partition_threshold;
        });
    }

    std::sort(args.begin(), args.end(), by(std::less<>{}, sort_key));

    // The range is now ordered by weight first, so a binary search finds the
    // first argument heavier than the threshold.
    return std::upper_bound(
        args.begin(),
        args.end(),
        min_partition_threshold,
        [&](std::size_t limit, auto arg) { return limit < this->weights[arg]; });
}
struct
partition
{
std
::
size_t
weight
=
0
;
...
...
@@ -64,22 +97,24 @@ struct stream_info
void
assign_streams
(
program
&
p
,
std
::
size_t
n
)
{
const
std
::
size_t
min_partition_threshold
=
2
;
partition
critical
;
std
::
unordered_map
<
instruction_ref
,
std
::
deque
<
partition
>>
partitions
;
partitions
.
reserve
(
weights
.
size
());
fix
([
&
](
auto
self
,
auto
ins
,
auto
&
part
)
{
// If weight is zero then stop
if
(
this
->
weights
[
ins
]
==
0
)
if
(
contains
(
partitions
,
ins
))
return
;
partitions
[
ins
];
part
.
add
(
ins
,
this
->
iweights
[
ins
]);
auto
max_it
=
std
::
max_element
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
by
(
std
::
less
<>
{},
index_of
(
this
->
weights
)));
for
(
auto
i
:
ins
->
inputs
())
auto
args
=
ins
->
inputs
();
auto
threshold_it
=
sort_args
(
args
);
for
(
auto
i
:
range
(
args
.
begin
(),
threshold_it
))
{
const
auto
weight
=
this
->
weights
[
i
];
if
(
i
==
*
max_it
or
weight
<=
min_partition_threshold
)
self
(
i
,
part
);
}
for
(
auto
i
:
range
(
threshold_it
,
args
.
end
()))
{
if
(
i
==
args
.
back
())
{
self
(
i
,
part
);
}
...
...
@@ -89,6 +124,8 @@ struct stream_info
self
(
i
,
partitions
[
ins
].
back
());
}
}
// Sort instructions
p
.
move_instruction
(
ins
,
p
.
end
());
})(
std
::
prev
(
p
.
end
()),
critical
);
// Set the critical partition to stream 0
...
...
@@ -233,6 +270,8 @@ struct stream_info
{
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
result
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
merge_from
;
result
.
reserve
(
p
.
size
());
merge_from
.
reserve
(
p
.
size
());
for
(
auto
ins
:
reverse_iterator_for
(
p
))
{
for
(
auto
&&
arg
:
ins
->
outputs
())
...
...
@@ -254,7 +293,7 @@ struct stream_info
auto
&&
r
=
result
[
merge
][
stream
];
r
.
push_back
(
ins
);
// Copy inputs if they don't have a stream (and are not a builtin and context
// free) Inputs without a stream can have a implicit dependency
// free). Inputs without a stream can have an implicit dependency
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
r
),
...
...
@@ -277,19 +316,6 @@ void schedule::apply(program& p) const
si
.
accumulate_weights
(
last
,
model
);
si
.
assign_streams
(
p
,
model
.
concurrency
());
// Topo sort
fix
([
&
](
auto
self
,
auto
ins
)
{
auto
args
=
ins
->
inputs
();
std
::
sort
(
args
.
begin
(),
args
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
auto
x
)
{
return
std
::
make_tuple
(
si
.
weights
[
x
],
x
->
inputs
().
size
());
}));
for
(
auto
i
:
args
)
{
p
.
move_instruction
(
i
,
p
.
begin
());
self
(
i
);
}
})(
last
);
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
{
p
.
annotate
(
std
::
cout
,
[
&
](
auto
ins
)
{
...
...
@@ -308,10 +334,12 @@ void schedule::apply(program& p) const
}
// Schedule instructions
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
ins2wait
;
std
::
size_t
wait_id
=
0
;
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
ins2wait
;
std
::
unordered_map
<
std
::
size_t
,
std
::
unordered_set
<
std
::
size_t
>>
waited_for
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
std
::
size_t
>>
ins2waited
;
ins2wait
.
reserve
(
p
.
size
());
ins2waited
.
reserve
(
p
.
size
());
for
(
auto
ins
:
iterator_for
(
p
))
{
// Only schedule instructions that have a stream
...
...
@@ -364,6 +392,10 @@ void schedule::apply(program& p) const
dfor
(
merge
.
second
.
size
(),
merge
.
second
.
size
())([
&
](
auto
i
,
auto
j
)
{
if
(
i
==
j
)
return
;
if
(
merge
.
second
[
i
].
empty
())
return
;
if
(
merge
.
second
[
j
].
empty
())
return
;
for
(
auto
ins1
:
merge
.
second
[
i
])
{
auto
args
=
merge
.
second
[
j
];
...
...
test/schedule_test.cpp
View file @
4008675f
...
...
@@ -521,17 +521,19 @@ TEST_CASE(seq_merge)
p
.
compile
(
t
);
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
t
.
get_stream
(
i1
)
==
2
);
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
i2
));
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
c1
.
back
()));
for
(
auto
ins
:
c1
)
EXPECT
(
t
.
get_stream
(
ins
)
==
3
);
EXPECT
(
t
.
get_stream
(
binary1
)
==
3
);
EXPECT
(
t
.
get_stream
(
ins
)
==
t
.
get_stream
(
c1
.
back
())
);
EXPECT
(
t
.
get_stream
(
binary1
)
==
t
.
get_stream
(
c1
.
back
())
);
EXPECT
(
get_wait_for
(
binary1
)
==
get_wait_for
(
t
.
get_stream
(
binary1
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
check_conflicts
(
p
,
{
c1
,
{
i1
}});
EXPECT
(
t
.
get_stream
(
i2
)
=
=
3
);
EXPECT
(
t
.
get_stream
(
i2
)
!
=
t
.
get_stream
(
c2
.
back
())
);
for
(
auto
ins
:
c2
)
EXPECT
(
t
.
get_stream
(
ins
)
==
0
);
EXPECT
(
t
.
get_stream
(
ins
)
==
t
.
get_stream
(
c2
.
back
()));
EXPECT
(
t
.
get_stream
(
c1
.
back
())
!=
t
.
get_stream
(
c2
.
back
()));
EXPECT
(
t
.
get_stream
(
binary2
)
==
0
);
EXPECT
(
get_wait_for
(
binary2
)
==
get_wait_for
(
t
.
get_stream
(
binary2
),
{
t
.
get_stream
(
c2
.
back
()),
t
.
get_stream
(
i2
)}));
...
...
@@ -559,7 +561,7 @@ TEST_CASE(par_merge)
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
t
.
get_stream
(
binary3
)
==
0
);
EXPECT
(
t
.
get_stream
(
i1
)
=
=
2
);
EXPECT
(
t
.
get_stream
(
i1
)
!
=
t
.
get_stream
(
i2
)
);
for
(
auto
ins
:
c1
)
EXPECT
(
t
.
get_stream
(
ins
)
==
0
);
EXPECT
(
t
.
get_stream
(
binary1
)
==
0
);
...
...
@@ -567,7 +569,6 @@ TEST_CASE(par_merge)
get_wait_for
(
t
.
get_stream
(
binary1
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
check_conflicts
(
p
,
{
c1
,
{
i1
}});
EXPECT
(
t
.
get_stream
(
i2
)
==
1
);
for
(
auto
ins
:
c2
)
EXPECT
(
t
.
get_stream
(
ins
)
==
3
);
EXPECT
(
t
.
get_stream
(
binary2
)
==
3
);
...
...
@@ -684,7 +685,8 @@ TEST_CASE(inner_split1)
auto
output
=
p
.
add_instruction
(
nary_op
{},
i1
,
s1
,
s2
);
p
.
compile
(
t
);
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
t
.
get_stream
(
i1
)
==
3
);
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
s1
));
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
s2
));
for
(
auto
ins
:
c1
)
EXPECT
(
t
.
get_stream
(
ins
)
!=
t
.
get_stream
(
i1
));
EXPECT
(
t
.
get_stream
(
s1
)
!=
t
.
get_stream
(
s2
));
...
...
@@ -709,7 +711,8 @@ TEST_CASE(inner_split2)
auto
output
=
p
.
add_instruction
(
nary_op
{},
i1
,
s1
.
back
(),
s2
.
back
());
p
.
compile
(
t
);
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
t
.
get_stream
(
i1
)
==
2
);
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
s1
.
back
()));
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
s2
.
back
()));
for
(
auto
ins
:
c1
)
EXPECT
(
t
.
get_stream
(
ins
)
!=
t
.
get_stream
(
i1
));
EXPECT
(
t
.
get_stream
(
s1
.
back
())
!=
t
.
get_stream
(
s2
.
back
()));
...
...
@@ -735,7 +738,7 @@ TEST_CASE(inception_resnet)
auto
output
=
p
.
add_instruction
(
nary_op
{},
binary
,
input
);
p
.
compile
(
t
);
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
t
.
get_stream
(
i1
)
=
=
2
);
EXPECT
(
t
.
get_stream
(
i1
)
!
=
0
);
for
(
auto
ins
:
c1
)
EXPECT
(
t
.
get_stream
(
ins
)
==
0
);
EXPECT
(
t
.
get_stream
(
binary
)
==
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment