Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6b4175e8
Commit
6b4175e8
authored
Mar 08, 2019
by
Paul
Browse files
Add inception like test
parent
cec7544c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
143 additions
and
2 deletions
+143
-2
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+6
-1
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-1
test/schedule_test.cpp
test/schedule_test.cpp
+136
-0
No files found.
src/include/migraphx/streamutils.hpp
View file @
6b4175e8
...
@@ -36,6 +36,11 @@ inline stream_range_container<Range> stream_range(const Range& r)
...
@@ -36,6 +36,11 @@ inline stream_range_container<Range> stream_range(const Range& r)
namespace
detail
{
namespace
detail
{
inline
void
stream_write_value_impl
(
rank
<
2
>
,
std
::
ostream
&
os
,
const
std
::
string
&
x
)
{
os
<<
x
;
}
template
<
class
Range
>
template
<
class
Range
>
auto
stream_write_value_impl
(
rank
<
1
>
,
std
::
ostream
&
os
,
const
Range
&
r
)
auto
stream_write_value_impl
(
rank
<
1
>
,
std
::
ostream
&
os
,
const
Range
&
r
)
->
decltype
(
r
.
begin
(),
r
.
end
(),
void
())
->
decltype
(
r
.
begin
(),
r
.
end
(),
void
())
...
@@ -53,7 +58,7 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
...
@@ -53,7 +58,7 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
template
<
class
T
>
template
<
class
T
>
void
stream_write_value
(
std
::
ostream
&
os
,
const
T
&
x
)
void
stream_write_value
(
std
::
ostream
&
os
,
const
T
&
x
)
{
{
detail
::
stream_write_value_impl
(
rank
<
1
>
{},
os
,
x
);
detail
::
stream_write_value_impl
(
rank
<
2
>
{},
os
,
x
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/target.cpp
View file @
6b4175e8
...
@@ -54,7 +54,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
...
@@ -54,7 +54,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination
{},
dead_code_elimination
{},
write_literals
{
&
ctx
},
write_literals
{
&
ctx
},
schedule
{
gpu
::
schedule_model
{
ctx
.
get_current_device
().
nstreams
()}},
schedule
{
gpu
::
schedule_model
{
ctx
.
get_current_device
().
nstreams
()}},
//
memory_coloring{"hip::allocate"},
memory_coloring
{
"hip::allocate"
},
dead_code_elimination
{},
dead_code_elimination
{},
// eliminate_workspace{},
// eliminate_workspace{},
eliminate_allocation
{
"hip::allocate"
},
eliminate_allocation
{
"hip::allocate"
},
...
...
test/schedule_test.cpp
View file @
6b4175e8
...
@@ -30,6 +30,12 @@ struct unary_op
...
@@ -30,6 +30,12 @@ struct unary_op
struct
nary_op
struct
nary_op
{
{
std
::
string
comment
=
""
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
migraphx
::
pack
(
f
(
self
.
comment
,
"comment"
));
}
std
::
string
name
()
const
{
return
"nary"
;
}
std
::
string
name
()
const
{
return
"nary"
;
}
migraphx
::
argument
migraphx
::
argument
compute
(
migraphx
::
context
&
,
const
migraphx
::
shape
&
,
std
::
vector
<
migraphx
::
argument
>
args
)
const
compute
(
migraphx
::
context
&
,
const
migraphx
::
shape
&
,
std
::
vector
<
migraphx
::
argument
>
args
)
const
...
@@ -119,6 +125,15 @@ struct schedule_target
...
@@ -119,6 +125,15 @@ struct schedule_target
std
::
size_t
get_stream
(
migraphx
::
instruction_ref
ins
)
{
return
model
.
ins2stream
->
at
(
ins
);
}
std
::
size_t
get_stream
(
migraphx
::
instruction_ref
ins
)
{
return
model
.
ins2stream
->
at
(
ins
);
}
std
::
vector
<
std
::
size_t
>
get_streams
(
std
::
vector
<
migraphx
::
instruction_ref
>
inss
)
{
std
::
vector
<
std
::
size_t
>
result
;
std
::
transform
(
inss
.
begin
(),
inss
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
auto
ins
)
{
return
this
->
get_stream
(
ins
);
});
return
result
;
}
bool
has_stream
(
migraphx
::
instruction_ref
ins
)
{
return
model
.
ins2stream
->
count
(
ins
)
>
0
;
}
bool
has_stream
(
migraphx
::
instruction_ref
ins
)
{
return
model
.
ins2stream
->
count
(
ins
)
>
0
;
}
};
};
...
@@ -565,4 +580,125 @@ TEST_CASE(par_merge_multi_entry)
...
@@ -565,4 +580,125 @@ TEST_CASE(par_merge_multi_entry)
EXPECT
(
check_conflicts
(
p
,
binary1
,
binary2
));
EXPECT
(
check_conflicts
(
p
,
binary1
,
binary2
));
check_conflicts
(
p
,
{
c1
,
{
i1
},
c2
,
{
i2
}});
check_conflicts
(
p
,
{
c1
,
{
i1
},
c2
,
{
i2
}});
}
}
TEST_CASE
(
inception1
)
{
schedule_target
t
{};
migraphx
::
program
p
;
auto
i1
=
p
.
add_literal
(
0
);
auto
i2
=
p
.
add_literal
(
1
);
auto
i3
=
p
.
add_literal
(
1
);
auto
i4
=
p
.
add_literal
(
2
);
auto
i7
=
p
.
add_instruction
(
nary_op
{
"i7"
},
i1
,
i4
,
i3
,
i2
);
auto
i8
=
p
.
add_literal
(
2
);
auto
i9
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i8
);
auto
i10
=
p
.
add_literal
(
1
);
auto
i11
=
p
.
add_instruction
(
nary_op
{
"i11"
},
i7
,
i9
,
i10
);
auto
i12
=
p
.
add_literal
(
2
);
auto
i13
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i12
);
auto
i14
=
p
.
add_literal
(
1
);
auto
i15
=
p
.
add_literal
(
1
);
auto
i16
=
p
.
add_literal
(
2
);
auto
i17
=
p
.
add_instruction
(
nary_op
{
"i17"
},
i11
,
i16
,
i15
,
i13
,
i14
);
auto
i18
=
p
.
add_literal
(
2
);
auto
i19
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i18
);
auto
i20
=
p
.
add_literal
(
1
);
auto
i21
=
p
.
add_literal
(
1
);
auto
i22
=
p
.
add_literal
(
2
);
auto
i23
=
p
.
add_instruction
(
nary_op
{
"i23"
},
i17
,
i22
,
i21
,
i19
,
i20
);
auto
i24
=
p
.
add_literal
(
1
);
auto
i25
=
p
.
add_instruction
(
nary_op
{
"i25"
},
i23
,
i24
);
auto
i26
=
p
.
add_literal
(
2
);
auto
i27
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i26
);
auto
i28
=
p
.
add_literal
(
1
);
auto
i29
=
p
.
add_literal
(
1
);
auto
i30
=
p
.
add_literal
(
2
);
auto
i31
=
p
.
add_instruction
(
nary_op
{
"i31"
},
i25
,
i30
,
i29
,
i27
,
i28
);
auto
i32
=
p
.
add_literal
(
2
);
auto
i33
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i32
);
auto
i34
=
p
.
add_literal
(
1
);
auto
i35
=
p
.
add_literal
(
1
);
auto
i36
=
p
.
add_literal
(
2
);
auto
i37
=
p
.
add_instruction
(
nary_op
{
"i37"
},
i31
,
i36
,
i35
,
i33
,
i34
);
auto
i38
=
p
.
add_literal
(
1
);
auto
i39
=
p
.
add_instruction
(
nary_op
{
"i39"
},
i37
,
i38
);
auto
i41
=
p
.
add_literal
(
2
);
auto
i42
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i41
);
auto
i43
=
p
.
add_literal
(
1
);
auto
i44
=
p
.
add_literal
(
1
);
auto
i45
=
p
.
add_literal
(
2
);
auto
i48
=
p
.
add_instruction
(
nary_op
{
"i48"
},
i39
,
i45
,
i44
,
i42
,
i43
);
auto
i49
=
p
.
add_literal
(
2
);
auto
i50
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i49
);
auto
i51
=
p
.
add_literal
(
1
);
auto
i52
=
p
.
add_literal
(
1
);
auto
i53
=
p
.
add_literal
(
2
);
auto
i54
=
p
.
add_instruction
(
nary_op
{
"i54"
},
i48
,
i53
,
i52
,
i50
,
i51
);
auto
i55
=
p
.
add_literal
(
1
);
auto
i56
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i55
);
auto
i57
=
p
.
add_literal
(
2
);
auto
i58
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i57
);
auto
i59
=
p
.
add_literal
(
1
);
auto
i60
=
p
.
add_literal
(
2
);
auto
i61
=
p
.
add_instruction
(
nary_op
{
"i61"
},
i54
,
i60
,
i59
,
i58
,
i56
);
auto
i62
=
p
.
add_literal
(
2
);
auto
i63
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i62
);
auto
i64
=
p
.
add_literal
(
1
);
auto
i65
=
p
.
add_literal
(
1
);
auto
i66
=
p
.
add_literal
(
2
);
auto
i69
=
p
.
add_instruction
(
nary_op
{
"i69"
},
i39
,
i66
,
i65
,
i63
,
i64
);
auto
i70
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i55
);
auto
i71
=
p
.
add_literal
(
2
);
auto
i72
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i71
);
auto
i73
=
p
.
add_literal
(
1
);
auto
i74
=
p
.
add_literal
(
2
);
auto
i75
=
p
.
add_instruction
(
nary_op
{
"i75"
},
i69
,
i74
,
i73
,
i72
,
i70
);
auto
i77
=
p
.
add_literal
(
1
);
auto
i80
=
p
.
add_instruction
(
nary_op
{
"i80"
},
i39
,
i77
);
auto
i81
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i55
);
auto
i82
=
p
.
add_literal
(
2
);
auto
i83
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i82
);
auto
i84
=
p
.
add_literal
(
1
);
auto
i85
=
p
.
add_literal
(
2
);
auto
i86
=
p
.
add_instruction
(
nary_op
{
"i86"
},
i80
,
i85
,
i84
,
i83
,
i81
);
auto
i88
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i55
);
auto
i89
=
p
.
add_literal
(
2
);
auto
i90
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i89
);
auto
i91
=
p
.
add_literal
(
1
);
auto
i92
=
p
.
add_literal
(
2
);
auto
i94
=
p
.
add_instruction
(
nary_op
{
"i94"
},
i39
,
i92
,
i91
,
i90
,
i88
);
auto
i96
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i55
,
i94
,
i75
,
i61
,
i86
);
auto
i97
=
p
.
add_literal
(
2
);
auto
i98
=
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
i97
);
auto
i99
=
p
.
add_literal
(
3
);
auto
i100
=
p
.
add_literal
(
1
);
auto
i101
=
p
.
add_literal
(
2
);
auto
output
=
p
.
add_instruction
(
nary_op
{
"output"
},
i96
,
i101
,
i100
,
i98
,
i99
);
p
.
compile
(
t
);
EXPECT
(
t
.
get_streams
({
i7
,
i11
,
i17
,
i23
,
i25
,
i31
,
i37
,
i39
,
i94
})
==
t
.
get_streams
({
i7
,
i7
,
i7
,
i7
,
i7
,
i7
,
i7
,
i7
,
i7
}));
EXPECT
(
t
.
get_streams
({
i48
,
i54
,
i61
,
output
})
==
t
.
get_streams
({
output
,
output
,
output
,
output
}));
EXPECT
(
t
.
get_streams
({
i80
,
i86
})
==
t
.
get_streams
({
i80
,
i80
}));
EXPECT
(
t
.
get_streams
({
i69
,
i75
})
==
t
.
get_streams
({
i69
,
i69
}));
EXPECT
(
t
.
get_stream
(
i7
)
!=
t
.
get_stream
(
i80
));
EXPECT
(
t
.
get_stream
(
i69
)
!=
t
.
get_stream
(
i80
));
EXPECT
(
t
.
get_stream
(
i69
)
!=
t
.
get_stream
(
i7
));
EXPECT
(
t
.
get_stream
(
output
)
!=
t
.
get_stream
(
i7
));
EXPECT
(
t
.
get_stream
(
output
)
!=
t
.
get_stream
(
i69
));
EXPECT
(
t
.
get_stream
(
output
)
!=
t
.
get_stream
(
i80
));
EXPECT
(
get_wait_for
(
i48
)
==
get_wait_for
({
t
.
get_stream
(
i39
)}));
EXPECT
(
get_wait_for
(
i80
)
==
get_wait_for
({
t
.
get_stream
(
i39
)}));
EXPECT
(
get_wait_for
(
i69
)
==
get_wait_for
({
t
.
get_stream
(
i39
)}));
// We dont wait twice
EXPECT
(
get_wait_for
(
i94
).
empty
());
EXPECT
(
get_wait_for
(
output
)
==
get_wait_for
(
t
.
get_stream
(
output
),
{
t
.
get_stream
(
i94
),
t
.
get_stream
(
i75
),
t
.
get_stream
(
i61
),
t
.
get_stream
(
i86
)}));
check_conflicts
(
p
,
{{
i80
,
i86
},
{
i69
,
i75
},
{
i48
,
i54
,
i61
,
output
},
{
i94
}});
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment