Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4008675f
Commit
4008675f
authored
Mar 11, 2019
by
Paul
Browse files
Sort while creating partitions
parent
6e9142b5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
69 additions
and
36 deletions
+69
-36
src/program.cpp
src/program.cpp
+0
-2
src/schedule.cpp
src/schedule.cpp
+56
-24
test/schedule_test.cpp
test/schedule_test.cpp
+13
-10
No files found.
src/program.cpp
View file @
4008675f
...
@@ -104,11 +104,9 @@ instruction_ref program::insert_instruction(instruction_ref ins,
...
@@ -104,11 +104,9 @@ instruction_ref program::insert_instruction(instruction_ref ins,
args
.
begin
(),
args
.
end
(),
[
&
](
instruction_ref
x
)
{
return
has_instruction
(
x
);
})
&&
args
.
begin
(),
args
.
end
(),
[
&
](
instruction_ref
x
)
{
return
has_instruction
(
x
);
})
&&
"Argument is not an exisiting instruction"
);
"Argument is not an exisiting instruction"
);
assert
(
not
starts_with
(
op
.
name
(),
"@"
));
assert
(
not
starts_with
(
op
.
name
(),
"@"
));
// TODO: Use move
shape
r
=
compute_shape
(
op
,
args
);
shape
r
=
compute_shape
(
op
,
args
);
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
std
::
move
(
args
)});
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
std
::
move
(
args
)});
instruction
::
backreference
(
result
);
instruction
::
backreference
(
result
);
// assert(result->inputs() == args);
assert
(
result
->
valid
(
begin
()));
assert
(
result
->
valid
(
begin
()));
return
result
;
return
result
;
}
}
...
...
src/schedule.cpp
View file @
4008675f
...
@@ -50,6 +50,39 @@ struct stream_info
...
@@ -50,6 +50,39 @@ struct stream_info
})(
last
);
})(
last
);
}
}
std
::
vector
<
instruction_ref
>::
iterator
sort_args
(
std
::
vector
<
instruction_ref
>&
args
)
{
const
std
::
size_t
min_partition_threshold
=
2
;
auto
compare
=
by
(
std
::
less
<>
{},
[
&
](
auto
x
)
{
return
std
::
make_tuple
(
this
->
weights
[
x
],
x
->
inputs
().
size
());
});
if
(
args
.
size
()
<
2
)
{
return
args
.
end
();
}
else
if
(
args
.
size
()
==
2
)
{
auto
w1
=
this
->
weights
[
args
[
0
]];
auto
w2
=
this
->
weights
[
args
[
1
]];
if
(
std
::
make_tuple
(
w1
,
args
[
0
]
->
inputs
().
size
())
>
std
::
make_tuple
(
w2
,
args
[
1
]
->
inputs
().
size
()))
{
std
::
swap
(
args
[
0
],
args
[
1
]);
std
::
swap
(
w1
,
w2
);
}
if
(
w1
>
min_partition_threshold
)
return
args
.
begin
();
if
(
w2
>
min_partition_threshold
)
return
args
.
begin
()
+
1
;
return
args
.
end
();
}
std
::
sort
(
args
.
begin
(),
args
.
end
(),
compare
);
return
std
::
upper_bound
(
args
.
begin
(),
args
.
end
(),
min_partition_threshold
,
[
&
](
std
::
size_t
w
,
auto
i
)
{
return
w
<
this
->
weights
[
i
];
});
}
struct
partition
struct
partition
{
{
std
::
size_t
weight
=
0
;
std
::
size_t
weight
=
0
;
...
@@ -64,22 +97,24 @@ struct stream_info
...
@@ -64,22 +97,24 @@ struct stream_info
void
assign_streams
(
program
&
p
,
std
::
size_t
n
)
void
assign_streams
(
program
&
p
,
std
::
size_t
n
)
{
{
const
std
::
size_t
min_partition_threshold
=
2
;
partition
critical
;
partition
critical
;
std
::
unordered_map
<
instruction_ref
,
std
::
deque
<
partition
>>
partitions
;
std
::
unordered_map
<
instruction_ref
,
std
::
deque
<
partition
>>
partitions
;
partitions
.
reserve
(
weights
.
size
());
fix
([
&
](
auto
self
,
auto
ins
,
auto
&
part
)
{
fix
([
&
](
auto
self
,
auto
ins
,
auto
&
part
)
{
// If weight is zero then stop
if
(
contains
(
partitions
,
ins
))
if
(
this
->
weights
[
ins
]
==
0
)
return
;
return
;
partitions
[
ins
];
part
.
add
(
ins
,
this
->
iweights
[
ins
]);
part
.
add
(
ins
,
this
->
iweights
[
ins
]);
auto
max_it
=
std
::
max_element
(
ins
->
inputs
().
begin
(),
auto
args
=
ins
->
inputs
();
ins
->
inputs
().
end
(),
auto
threshold_it
=
sort_args
(
args
);
by
(
std
::
less
<>
{},
index_of
(
this
->
weights
)));
for
(
auto
i
:
range
(
args
.
begin
(),
threshold_it
))
for
(
auto
i
:
ins
->
inputs
())
{
{
const
auto
weight
=
this
->
weights
[
i
];
self
(
i
,
part
);
if
(
i
==
*
max_it
or
weight
<=
min_partition_threshold
)
}
for
(
auto
i
:
range
(
threshold_it
,
args
.
end
()))
{
if
(
i
==
args
.
back
())
{
{
self
(
i
,
part
);
self
(
i
,
part
);
}
}
...
@@ -89,6 +124,8 @@ struct stream_info
...
@@ -89,6 +124,8 @@ struct stream_info
self
(
i
,
partitions
[
ins
].
back
());
self
(
i
,
partitions
[
ins
].
back
());
}
}
}
}
// Sort instructions
p
.
move_instruction
(
ins
,
p
.
end
());
})(
std
::
prev
(
p
.
end
()),
critical
);
})(
std
::
prev
(
p
.
end
()),
critical
);
// Set the critical partition to stream 0
// Set the critical partition to stream 0
...
@@ -233,6 +270,8 @@ struct stream_info
...
@@ -233,6 +270,8 @@ struct stream_info
{
{
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
result
;
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
result
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
merge_from
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
merge_from
;
result
.
reserve
(
p
.
size
());
merge_from
.
reserve
(
p
.
size
());
for
(
auto
ins
:
reverse_iterator_for
(
p
))
for
(
auto
ins
:
reverse_iterator_for
(
p
))
{
{
for
(
auto
&&
arg
:
ins
->
outputs
())
for
(
auto
&&
arg
:
ins
->
outputs
())
...
@@ -254,7 +293,7 @@ struct stream_info
...
@@ -254,7 +293,7 @@ struct stream_info
auto
&&
r
=
result
[
merge
][
stream
];
auto
&&
r
=
result
[
merge
][
stream
];
r
.
push_back
(
ins
);
r
.
push_back
(
ins
);
// Copy inputs if they dont have a stream(and are not a builtin and context
// Copy inputs if they dont have a stream(and are not a builtin and context
// free) Inputs without a stream can have a implicit dependency
// free)
.
Inputs without a stream can have a implicit dependency
std
::
copy_if
(
ins
->
inputs
().
begin
(),
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
r
),
std
::
back_inserter
(
r
),
...
@@ -277,19 +316,6 @@ void schedule::apply(program& p) const
...
@@ -277,19 +316,6 @@ void schedule::apply(program& p) const
si
.
accumulate_weights
(
last
,
model
);
si
.
accumulate_weights
(
last
,
model
);
si
.
assign_streams
(
p
,
model
.
concurrency
());
si
.
assign_streams
(
p
,
model
.
concurrency
());
// Topo sort
fix
([
&
](
auto
self
,
auto
ins
)
{
auto
args
=
ins
->
inputs
();
std
::
sort
(
args
.
begin
(),
args
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
auto
x
)
{
return
std
::
make_tuple
(
si
.
weights
[
x
],
x
->
inputs
().
size
());
}));
for
(
auto
i
:
args
)
{
p
.
move_instruction
(
i
,
p
.
begin
());
self
(
i
);
}
})(
last
);
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
{
{
p
.
annotate
(
std
::
cout
,
[
&
](
auto
ins
)
{
p
.
annotate
(
std
::
cout
,
[
&
](
auto
ins
)
{
...
@@ -308,10 +334,12 @@ void schedule::apply(program& p) const
...
@@ -308,10 +334,12 @@ void schedule::apply(program& p) const
}
}
// Schedule instructions
// Schedule instructions
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
ins2wait
;
std
::
size_t
wait_id
=
0
;
std
::
size_t
wait_id
=
0
;
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
ins2wait
;
std
::
unordered_map
<
std
::
size_t
,
std
::
unordered_set
<
std
::
size_t
>>
waited_for
;
std
::
unordered_map
<
std
::
size_t
,
std
::
unordered_set
<
std
::
size_t
>>
waited_for
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
std
::
size_t
>>
ins2waited
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
std
::
size_t
>>
ins2waited
;
ins2wait
.
reserve
(
p
.
size
());
ins2waited
.
reserve
(
p
.
size
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
// Only schedule instructions that have a stream
// Only schedule instructions that have a stream
...
@@ -364,6 +392,10 @@ void schedule::apply(program& p) const
...
@@ -364,6 +392,10 @@ void schedule::apply(program& p) const
dfor
(
merge
.
second
.
size
(),
merge
.
second
.
size
())([
&
](
auto
i
,
auto
j
)
{
dfor
(
merge
.
second
.
size
(),
merge
.
second
.
size
())([
&
](
auto
i
,
auto
j
)
{
if
(
i
==
j
)
if
(
i
==
j
)
return
;
return
;
if
(
merge
.
second
[
i
].
empty
())
return
;
if
(
merge
.
second
[
j
].
empty
())
return
;
for
(
auto
ins1
:
merge
.
second
[
i
])
for
(
auto
ins1
:
merge
.
second
[
i
])
{
{
auto
args
=
merge
.
second
[
j
];
auto
args
=
merge
.
second
[
j
];
...
...
test/schedule_test.cpp
View file @
4008675f
...
@@ -521,17 +521,19 @@ TEST_CASE(seq_merge)
...
@@ -521,17 +521,19 @@ TEST_CASE(seq_merge)
p
.
compile
(
t
);
p
.
compile
(
t
);
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
t
.
get_stream
(
i1
)
==
2
);
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
i2
));
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
c1
.
back
()));
for
(
auto
ins
:
c1
)
for
(
auto
ins
:
c1
)
EXPECT
(
t
.
get_stream
(
ins
)
==
3
);
EXPECT
(
t
.
get_stream
(
ins
)
==
t
.
get_stream
(
c1
.
back
())
);
EXPECT
(
t
.
get_stream
(
binary1
)
==
3
);
EXPECT
(
t
.
get_stream
(
binary1
)
==
t
.
get_stream
(
c1
.
back
())
);
EXPECT
(
get_wait_for
(
binary1
)
==
EXPECT
(
get_wait_for
(
binary1
)
==
get_wait_for
(
t
.
get_stream
(
binary1
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
get_wait_for
(
t
.
get_stream
(
binary1
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
check_conflicts
(
p
,
{
c1
,
{
i1
}});
check_conflicts
(
p
,
{
c1
,
{
i1
}});
EXPECT
(
t
.
get_stream
(
i2
)
=
=
3
);
EXPECT
(
t
.
get_stream
(
i2
)
!
=
t
.
get_stream
(
c2
.
back
())
);
for
(
auto
ins
:
c2
)
for
(
auto
ins
:
c2
)
EXPECT
(
t
.
get_stream
(
ins
)
==
0
);
EXPECT
(
t
.
get_stream
(
ins
)
==
t
.
get_stream
(
c2
.
back
()));
EXPECT
(
t
.
get_stream
(
c1
.
back
())
!=
t
.
get_stream
(
c2
.
back
()));
EXPECT
(
t
.
get_stream
(
binary2
)
==
0
);
EXPECT
(
t
.
get_stream
(
binary2
)
==
0
);
EXPECT
(
get_wait_for
(
binary2
)
==
EXPECT
(
get_wait_for
(
binary2
)
==
get_wait_for
(
t
.
get_stream
(
binary2
),
{
t
.
get_stream
(
c2
.
back
()),
t
.
get_stream
(
i2
)}));
get_wait_for
(
t
.
get_stream
(
binary2
),
{
t
.
get_stream
(
c2
.
back
()),
t
.
get_stream
(
i2
)}));
...
@@ -559,7 +561,7 @@ TEST_CASE(par_merge)
...
@@ -559,7 +561,7 @@ TEST_CASE(par_merge)
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
t
.
get_stream
(
binary3
)
==
0
);
EXPECT
(
t
.
get_stream
(
binary3
)
==
0
);
EXPECT
(
t
.
get_stream
(
i1
)
=
=
2
);
EXPECT
(
t
.
get_stream
(
i1
)
!
=
t
.
get_stream
(
i2
)
);
for
(
auto
ins
:
c1
)
for
(
auto
ins
:
c1
)
EXPECT
(
t
.
get_stream
(
ins
)
==
0
);
EXPECT
(
t
.
get_stream
(
ins
)
==
0
);
EXPECT
(
t
.
get_stream
(
binary1
)
==
0
);
EXPECT
(
t
.
get_stream
(
binary1
)
==
0
);
...
@@ -567,7 +569,6 @@ TEST_CASE(par_merge)
...
@@ -567,7 +569,6 @@ TEST_CASE(par_merge)
get_wait_for
(
t
.
get_stream
(
binary1
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
get_wait_for
(
t
.
get_stream
(
binary1
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
check_conflicts
(
p
,
{
c1
,
{
i1
}});
check_conflicts
(
p
,
{
c1
,
{
i1
}});
EXPECT
(
t
.
get_stream
(
i2
)
==
1
);
for
(
auto
ins
:
c2
)
for
(
auto
ins
:
c2
)
EXPECT
(
t
.
get_stream
(
ins
)
==
3
);
EXPECT
(
t
.
get_stream
(
ins
)
==
3
);
EXPECT
(
t
.
get_stream
(
binary2
)
==
3
);
EXPECT
(
t
.
get_stream
(
binary2
)
==
3
);
...
@@ -684,7 +685,8 @@ TEST_CASE(inner_split1)
...
@@ -684,7 +685,8 @@ TEST_CASE(inner_split1)
auto
output
=
p
.
add_instruction
(
nary_op
{},
i1
,
s1
,
s2
);
auto
output
=
p
.
add_instruction
(
nary_op
{},
i1
,
s1
,
s2
);
p
.
compile
(
t
);
p
.
compile
(
t
);
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
t
.
get_stream
(
i1
)
==
3
);
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
s1
));
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
s2
));
for
(
auto
ins
:
c1
)
for
(
auto
ins
:
c1
)
EXPECT
(
t
.
get_stream
(
ins
)
!=
t
.
get_stream
(
i1
));
EXPECT
(
t
.
get_stream
(
ins
)
!=
t
.
get_stream
(
i1
));
EXPECT
(
t
.
get_stream
(
s1
)
!=
t
.
get_stream
(
s2
));
EXPECT
(
t
.
get_stream
(
s1
)
!=
t
.
get_stream
(
s2
));
...
@@ -709,7 +711,8 @@ TEST_CASE(inner_split2)
...
@@ -709,7 +711,8 @@ TEST_CASE(inner_split2)
auto
output
=
p
.
add_instruction
(
nary_op
{},
i1
,
s1
.
back
(),
s2
.
back
());
auto
output
=
p
.
add_instruction
(
nary_op
{},
i1
,
s1
.
back
(),
s2
.
back
());
p
.
compile
(
t
);
p
.
compile
(
t
);
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
t
.
get_stream
(
i1
)
==
2
);
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
s1
.
back
()));
EXPECT
(
t
.
get_stream
(
i1
)
!=
t
.
get_stream
(
s2
.
back
()));
for
(
auto
ins
:
c1
)
for
(
auto
ins
:
c1
)
EXPECT
(
t
.
get_stream
(
ins
)
!=
t
.
get_stream
(
i1
));
EXPECT
(
t
.
get_stream
(
ins
)
!=
t
.
get_stream
(
i1
));
EXPECT
(
t
.
get_stream
(
s1
.
back
())
!=
t
.
get_stream
(
s2
.
back
()));
EXPECT
(
t
.
get_stream
(
s1
.
back
())
!=
t
.
get_stream
(
s2
.
back
()));
...
@@ -735,7 +738,7 @@ TEST_CASE(inception_resnet)
...
@@ -735,7 +738,7 @@ TEST_CASE(inception_resnet)
auto
output
=
p
.
add_instruction
(
nary_op
{},
binary
,
input
);
auto
output
=
p
.
add_instruction
(
nary_op
{},
binary
,
input
);
p
.
compile
(
t
);
p
.
compile
(
t
);
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
not
t
.
has_stream
(
one
));
EXPECT
(
t
.
get_stream
(
i1
)
=
=
2
);
EXPECT
(
t
.
get_stream
(
i1
)
!
=
0
);
for
(
auto
ins
:
c1
)
for
(
auto
ins
:
c1
)
EXPECT
(
t
.
get_stream
(
ins
)
==
0
);
EXPECT
(
t
.
get_stream
(
ins
)
==
0
);
EXPECT
(
t
.
get_stream
(
binary
)
==
0
);
EXPECT
(
t
.
get_stream
(
binary
)
==
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment