Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
11e155c2
Commit
11e155c2
authored
Jun 13, 2022
by
Paul
Browse files
Merge
parents
8a9c5bce
aa7ff911
Changes
397
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
188 additions
and
96 deletions
+188
-96
src/cpp_generator.cpp
src/cpp_generator.cpp
+9
-1
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+6
-22
src/driver/alexnet.cpp
src/driver/alexnet.cpp
+3
-3
src/driver/inceptionv3.cpp
src/driver/inceptionv3.cpp
+14
-14
src/driver/main.cpp
src/driver/main.cpp
+8
-0
src/driver/marker_roctx.cpp
src/driver/marker_roctx.cpp
+1
-1
src/driver/perf.cpp
src/driver/perf.cpp
+1
-1
src/driver/resnet50.cpp
src/driver/resnet50.cpp
+2
-2
src/eliminate_allocation.cpp
src/eliminate_allocation.cpp
+4
-4
src/eliminate_common_subexpression.cpp
src/eliminate_common_subexpression.cpp
+10
-5
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+6
-6
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+26
-11
src/eliminate_data_type.cpp
src/eliminate_data_type.cpp
+8
-2
src/eliminate_identity.cpp
src/eliminate_identity.cpp
+8
-8
src/eliminate_pad.cpp
src/eliminate_pad.cpp
+1
-1
src/include/migraphx/adjust_allocation.hpp
src/include/migraphx/adjust_allocation.hpp
+1
-1
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+17
-12
src/include/migraphx/analyze_streams.hpp
src/include/migraphx/analyze_streams.hpp
+1
-1
src/include/migraphx/any_ptr.hpp
src/include/migraphx/any_ptr.hpp
+61
-0
src/include/migraphx/auto_contiguous.hpp
src/include/migraphx/auto_contiguous.hpp
+1
-1
No files found.
src/cpp_generator.cpp
View file @
11e155c2
...
@@ -88,6 +88,7 @@ struct cpp_generator_impl
...
@@ -88,6 +88,7 @@ struct cpp_generator_impl
std
::
stringstream
fs
{};
std
::
stringstream
fs
{};
std
::
size_t
function_count
=
0
;
std
::
size_t
function_count
=
0
;
std
::
function
<
std
::
string
(
std
::
string
)
>
fmap
=
nullptr
;
std
::
function
<
std
::
string
(
std
::
string
)
>
fmap
=
nullptr
;
std
::
function
<
std
::
string
(
shape
)
>
fresult
=
nullptr
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
point_op_map
=
{};
std
::
unordered_map
<
std
::
string
,
std
::
string
>
point_op_map
=
{};
};
};
cpp_generator
::
cpp_generator
()
:
impl
(
std
::
make_unique
<
cpp_generator_impl
>
())
{}
cpp_generator
::
cpp_generator
()
:
impl
(
std
::
make_unique
<
cpp_generator_impl
>
())
{}
...
@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
...
@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
void
cpp_generator
::
fmap
(
const
std
::
function
<
std
::
string
(
std
::
string
)
>&
f
)
{
impl
->
fmap
=
f
;
}
void
cpp_generator
::
fmap
(
const
std
::
function
<
std
::
string
(
std
::
string
)
>&
f
)
{
impl
->
fmap
=
f
;
}
void
cpp_generator
::
fresult
(
const
std
::
function
<
std
::
string
(
shape
)
>&
f
)
{
impl
->
fresult
=
f
;
}
void
cpp_generator
::
add_point_op
(
const
std
::
string
&
op_name
,
const
std
::
string
&
code
)
void
cpp_generator
::
add_point_op
(
const
std
::
string
&
op_name
,
const
std
::
string
&
code
)
{
{
impl
->
point_op_map
[
op_name
]
=
code
;
impl
->
point_op_map
[
op_name
]
=
code
;
...
@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
...
@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins
->
inputs
().
end
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
args
),
std
::
back_inserter
(
args
),
[
&
](
auto
i
)
{
return
names
.
at
(
i
);
});
[
&
](
auto
i
)
{
return
names
.
at
(
i
);
});
return
this
->
generate_point_op
(
ins
->
get_operator
(),
args
);
auto
s
=
this
->
generate_point_op
(
ins
->
get_operator
(),
args
);
if
(
impl
->
fresult
)
return
impl
->
fresult
(
ins
->
get_shape
())
+
'('
+
s
+
')'
;
else
return
s
;
});
});
return
f
;
return
f
;
}
}
...
...
src/dead_code_elimination.cpp
View file @
11e155c2
...
@@ -9,26 +9,6 @@
...
@@ -9,26 +9,6 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Range
,
class
Iterator
>
std
::
ptrdiff_t
bidistance
(
const
Range
&
r
,
Iterator
start
,
Iterator
last
)
{
auto
start_forward
=
start
;
auto
start_backwards
=
start
;
std
::
size_t
n
=
0
;
while
(
start_forward
!=
last
and
start_backwards
!=
last
)
{
n
++
;
if
(
start_forward
!=
r
.
end
())
start_forward
++
;
if
(
start_backwards
!=
r
.
begin
())
start_backwards
--
;
}
if
(
start_forward
==
last
)
return
n
;
else
return
-
n
;
}
void
dead_code_elimination
::
apply
(
program
&
p
)
const
{
p
.
remove_unused_modules
();
}
void
dead_code_elimination
::
apply
(
program
&
p
)
const
{
p
.
remove_unused_modules
();
}
void
dead_code_elimination
::
apply
(
module
&
m
)
const
void
dead_code_elimination
::
apply
(
module
&
m
)
const
...
@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
...
@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
if
(
i
->
get_shape
().
elements
()
==
0
and
i
->
name
().
front
()
!=
'@'
and
if
(
i
->
get_shape
().
elements
()
==
0
and
i
->
name
().
front
()
!=
'@'
and
i
->
name
()
!=
"undefined"
and
i
->
name
()
!=
"identity"
)
i
->
name
()
!=
"undefined"
and
i
->
name
()
!=
"identity"
)
continue
;
continue
;
assert
(
bidistance
(
m
,
i
,
last
)
>
0
);
assert
(
std
::
distance
(
m
.
begin
(),
i
)
<=
std
::
distance
(
m
.
begin
(),
last
));
std
::
unordered_set
<
instruction_ref
>
visited
;
fix
([
&
](
auto
self
,
auto
leaf
)
{
fix
([
&
](
auto
self
,
auto
leaf
)
{
if
(
not
m
.
has_instruction
(
leaf
))
if
(
not
m
.
has_instruction
(
leaf
))
return
;
return
;
if
(
leaf
->
outputs
().
empty
())
if
(
leaf
->
outputs
().
empty
())
{
{
// Dont visit inputs twice
if
(
not
visited
.
insert
(
leaf
).
second
)
return
;
std
::
unordered_set
<
instruction_ref
>
args
(
leaf
->
inputs
().
begin
(),
std
::
unordered_set
<
instruction_ref
>
args
(
leaf
->
inputs
().
begin
(),
leaf
->
inputs
().
end
());
leaf
->
inputs
().
end
());
leaf
->
clear_arguments
();
leaf
->
clear_arguments
();
assert
(
bi
distance
(
m
,
last
,
leaf
)
<
0
);
assert
(
std
::
distance
(
m
.
begin
(),
leaf
)
<
std
::
distance
(
m
.
begin
(),
last
)
);
assert
(
leaf
!=
ins
);
assert
(
leaf
!=
ins
);
if
(
leaf
->
name
()
!=
"@param"
)
if
(
leaf
->
name
()
!=
"@param"
)
m
.
move_instruction
(
leaf
,
m
.
end
());
m
.
move_instruction
(
leaf
,
m
.
end
());
...
...
src/driver/alexnet.cpp
View file @
11e155c2
...
@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu19
;
migraphx
::
op
::
relu
relu19
;
auto
mx19
=
mm
->
add_instruction
(
relu19
,
mx18
);
auto
mx19
=
mm
->
add_instruction
(
relu19
,
mx18
);
migraphx
::
op
::
pooling
pooling20
;
migraphx
::
op
::
pooling
pooling20
;
pooling20
.
mode
=
"max"
;
pooling20
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling20
.
padding
=
{
0
,
0
};
pooling20
.
padding
=
{
0
,
0
};
pooling20
.
stride
=
{
2
,
2
};
pooling20
.
stride
=
{
2
,
2
};
pooling20
.
lengths
=
{
3
,
3
};
pooling20
.
lengths
=
{
3
,
3
};
...
@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu24
;
migraphx
::
op
::
relu
relu24
;
auto
mx24
=
mm
->
add_instruction
(
relu24
,
mx23
);
auto
mx24
=
mm
->
add_instruction
(
relu24
,
mx23
);
migraphx
::
op
::
pooling
pooling25
;
migraphx
::
op
::
pooling
pooling25
;
pooling25
.
mode
=
"max"
;
pooling25
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling25
.
padding
=
{
0
,
0
};
pooling25
.
padding
=
{
0
,
0
};
pooling25
.
stride
=
{
2
,
2
};
pooling25
.
stride
=
{
2
,
2
};
pooling25
.
lengths
=
{
3
,
3
};
pooling25
.
lengths
=
{
3
,
3
};
...
@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu37
;
migraphx
::
op
::
relu
relu37
;
auto
mx37
=
mm
->
add_instruction
(
relu37
,
mx36
);
auto
mx37
=
mm
->
add_instruction
(
relu37
,
mx36
);
migraphx
::
op
::
pooling
pooling38
;
migraphx
::
op
::
pooling
pooling38
;
pooling38
.
mode
=
"max"
;
pooling38
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling38
.
padding
=
{
0
,
0
};
pooling38
.
padding
=
{
0
,
0
};
pooling38
.
stride
=
{
2
,
2
};
pooling38
.
stride
=
{
2
,
2
};
pooling38
.
lengths
=
{
3
,
3
};
pooling38
.
lengths
=
{
3
,
3
};
...
...
src/driver/inceptionv3.cpp
View file @
11e155c2
...
@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu492;
migraphx::op::relu relu492;
auto mx492 = mm->add_instruction(relu492, mx491);
auto mx492 = mm->add_instruction(relu492, mx491);
migraphx::op::pooling pooling493;
migraphx::op::pooling pooling493;
pooling493
.
mode
=
"max"
;
pooling493.mode =
migraphx::op::pooling_mode::max
;
pooling493.padding = {0, 0};
pooling493.padding = {0, 0};
pooling493.stride = {2, 2};
pooling493.stride = {2, 2};
pooling493.lengths = {3, 3};
pooling493.lengths = {3, 3};
...
@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu499;
migraphx::op::relu relu499;
auto mx499 = mm->add_instruction(relu499, mx498);
auto mx499 = mm->add_instruction(relu499, mx498);
migraphx::op::pooling pooling500;
migraphx::op::pooling pooling500;
pooling500
.
mode
=
"max"
;
pooling500.mode =
migraphx::op::pooling_mode::max
;
pooling500.padding = {0, 0};
pooling500.padding = {0, 0};
pooling500.stride = {2, 2};
pooling500.stride = {2, 2};
pooling500.lengths = {3, 3};
pooling500.lengths = {3, 3};
...
@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu518;
migraphx::op::relu relu518;
auto mx518 = mm->add_instruction(relu518, mx517);
auto mx518 = mm->add_instruction(relu518, mx517);
migraphx::op::pooling pooling519;
migraphx::op::pooling pooling519;
pooling519
.
mode
=
"
average
"
;
pooling519.mode =
migraphx::op::pooling_mode::
average;
pooling519.padding = {1, 1};
pooling519.padding = {1, 1};
pooling519.stride = {1, 1};
pooling519.stride = {1, 1};
pooling519.lengths = {3, 3};
pooling519.lengths = {3, 3};
...
@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu541;
migraphx::op::relu relu541;
auto mx541 = mm->add_instruction(relu541, mx540);
auto mx541 = mm->add_instruction(relu541, mx540);
migraphx::op::pooling pooling542;
migraphx::op::pooling pooling542;
pooling542
.
mode
=
"
average
"
;
pooling542.mode =
migraphx::op::pooling_mode::
average;
pooling542.padding = {1, 1};
pooling542.padding = {1, 1};
pooling542.stride = {1, 1};
pooling542.stride = {1, 1};
pooling542.lengths = {3, 3};
pooling542.lengths = {3, 3};
...
@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu564;
migraphx::op::relu relu564;
auto mx564 = mm->add_instruction(relu564, mx563);
auto mx564 = mm->add_instruction(relu564, mx563);
migraphx::op::pooling pooling565;
migraphx::op::pooling pooling565;
pooling565
.
mode
=
"
average
"
;
pooling565.mode =
migraphx::op::pooling_mode::
average;
pooling565.padding = {1, 1};
pooling565.padding = {1, 1};
pooling565.stride = {1, 1};
pooling565.stride = {1, 1};
pooling565.lengths = {3, 3};
pooling565.lengths = {3, 3};
...
@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu581;
migraphx::op::relu relu581;
auto mx581 = mm->add_instruction(relu581, mx580);
auto mx581 = mm->add_instruction(relu581, mx580);
migraphx::op::pooling pooling582;
migraphx::op::pooling pooling582;
pooling582
.
mode
=
"max"
;
pooling582.mode =
migraphx::op::pooling_mode::max
;
pooling582.padding = {0, 0};
pooling582.padding = {0, 0};
pooling582.stride = {2, 2};
pooling582.stride = {2, 2};
pooling582.lengths = {3, 3};
pooling582.lengths = {3, 3};
...
@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu610;
migraphx::op::relu relu610;
auto mx610 = mm->add_instruction(relu610, mx609);
auto mx610 = mm->add_instruction(relu610, mx609);
migraphx::op::pooling pooling611;
migraphx::op::pooling pooling611;
pooling611
.
mode
=
"
average
"
;
pooling611.mode =
migraphx::op::pooling_mode::
average;
pooling611.padding = {1, 1};
pooling611.padding = {1, 1};
pooling611.stride = {1, 1};
pooling611.stride = {1, 1};
pooling611.lengths = {3, 3};
pooling611.lengths = {3, 3};
...
@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu642;
migraphx::op::relu relu642;
auto mx642 = mm->add_instruction(relu642, mx641);
auto mx642 = mm->add_instruction(relu642, mx641);
migraphx::op::pooling pooling643;
migraphx::op::pooling pooling643;
pooling643
.
mode
=
"
average
"
;
pooling643.mode =
migraphx::op::pooling_mode::
average;
pooling643.padding = {1, 1};
pooling643.padding = {1, 1};
pooling643.stride = {1, 1};
pooling643.stride = {1, 1};
pooling643.lengths = {3, 3};
pooling643.lengths = {3, 3};
...
@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu674;
migraphx::op::relu relu674;
auto mx674 = mm->add_instruction(relu674, mx673);
auto mx674 = mm->add_instruction(relu674, mx673);
migraphx::op::pooling pooling675;
migraphx::op::pooling pooling675;
pooling675
.
mode
=
"
average
"
;
pooling675.mode =
migraphx::op::pooling_mode::
average;
pooling675.padding = {1, 1};
pooling675.padding = {1, 1};
pooling675.stride = {1, 1};
pooling675.stride = {1, 1};
pooling675.lengths = {3, 3};
pooling675.lengths = {3, 3};
...
@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu706;
migraphx::op::relu relu706;
auto mx706 = mm->add_instruction(relu706, mx705);
auto mx706 = mm->add_instruction(relu706, mx705);
migraphx::op::pooling pooling707;
migraphx::op::pooling pooling707;
pooling707
.
mode
=
"
average
"
;
pooling707.mode =
migraphx::op::pooling_mode::
average;
pooling707.padding = {1, 1};
pooling707.padding = {1, 1};
pooling707.stride = {1, 1};
pooling707.stride = {1, 1};
pooling707.lengths = {3, 3};
pooling707.lengths = {3, 3};
...
@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu729;
migraphx::op::relu relu729;
auto mx729 = mm->add_instruction(relu729, mx728);
auto mx729 = mm->add_instruction(relu729, mx728);
migraphx::op::pooling pooling730;
migraphx::op::pooling pooling730;
pooling730
.
mode
=
"max"
;
pooling730.mode =
migraphx::op::pooling_mode::max
;
pooling730.padding = {0, 0};
pooling730.padding = {0, 0};
pooling730.stride = {2, 2};
pooling730.stride = {2, 2};
pooling730.lengths = {3, 3};
pooling730.lengths = {3, 3};
...
@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat757.axis = 1;
concat757.axis = 1;
auto mx757 = mm->add_instruction(concat757, mx753, mx756);
auto mx757 = mm->add_instruction(concat757, mx753, mx756);
migraphx::op::pooling pooling758;
migraphx::op::pooling pooling758;
pooling758
.
mode
=
"
average
"
;
pooling758.mode =
migraphx::op::pooling_mode::
average;
pooling758.padding = {1, 1};
pooling758.padding = {1, 1};
pooling758.stride = {1, 1};
pooling758.stride = {1, 1};
pooling758.lengths = {3, 3};
pooling758.lengths = {3, 3};
...
@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat788.axis = 1;
concat788.axis = 1;
auto mx788 = mm->add_instruction(concat788, mx784, mx787);
auto mx788 = mm->add_instruction(concat788, mx784, mx787);
migraphx::op::pooling pooling789;
migraphx::op::pooling pooling789;
pooling789
.
mode
=
"
average
"
;
pooling789.mode =
migraphx::op::pooling_mode::
average;
pooling789.padding = {1, 1};
pooling789.padding = {1, 1};
pooling789.stride = {1, 1};
pooling789.stride = {1, 1};
pooling789.lengths = {3, 3};
pooling789.lengths = {3, 3};
...
@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat793.axis = 1;
concat793.axis = 1;
auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792);
auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792);
migraphx::op::pooling pooling794;
migraphx::op::pooling pooling794;
pooling794
.
mode
=
"
average
"
;
pooling794.mode =
migraphx::op::pooling_mode::
average;
pooling794.padding = {0, 0};
pooling794.padding = {0, 0};
pooling794.stride = {8, 8};
pooling794.stride = {8, 8};
pooling794.lengths = {8, 8};
pooling794.lengths = {8, 8};
...
...
src/driver/main.cpp
View file @
11e155c2
...
@@ -508,8 +508,10 @@ struct roctx : command<roctx>
...
@@ -508,8 +508,10 @@ struct roctx : command<roctx>
struct
op
:
command
<
op
>
struct
op
:
command
<
op
>
{
{
bool
show_ops
=
false
;
bool
show_ops
=
false
;
std
::
string
op_name
{};
void
parse
(
argument_parser
&
ap
)
void
parse
(
argument_parser
&
ap
)
{
{
ap
(
op_name
,
{},
ap
.
metavar
(
"<MIGraphX operator name>"
));
ap
(
show_ops
,
ap
(
show_ops
,
{
"--list"
,
"-l"
},
{
"--list"
,
"-l"
},
ap
.
help
(
"List all the operators of MIGraphX"
),
ap
.
help
(
"List all the operators of MIGraphX"
),
...
@@ -522,6 +524,12 @@ struct op : command<op>
...
@@ -522,6 +524,12 @@ struct op : command<op>
for
(
const
auto
&
name
:
get_operators
())
for
(
const
auto
&
name
:
get_operators
())
std
::
cout
<<
name
<<
std
::
endl
;
std
::
cout
<<
name
<<
std
::
endl
;
}
}
else
{
auto
op
=
load_op
(
op_name
);
std
::
cout
<<
op_name
<<
": "
<<
std
::
endl
;
std
::
cout
<<
to_pretty_json_string
(
op
.
to_value
())
<<
std
::
endl
;
}
}
}
};
};
...
...
src/driver/marker_roctx.cpp
View file @
11e155c2
...
@@ -17,7 +17,7 @@ class marker_roctx
...
@@ -17,7 +17,7 @@ class marker_roctx
std
::
function
<
int
(
const
char
*
)
>
sym_roctx_range_push
;
std
::
function
<
int
(
const
char
*
)
>
sym_roctx_range_push
;
std
::
function
<
int
()
>
sym_roctx_range_pop
;
std
::
function
<
int
()
>
sym_roctx_range_pop
;
uint64_t
range_id
;
uint64_t
range_id
=
0
;
public:
public:
marker_roctx
()
marker_roctx
()
...
...
src/driver/perf.cpp
View file @
11e155c2
...
@@ -87,6 +87,6 @@ target get_target(bool gpu)
...
@@ -87,6 +87,6 @@ target get_target(bool gpu)
void
compile_program
(
program
&
p
,
bool
gpu
)
{
p
.
compile
(
get_target
(
gpu
));
}
void
compile_program
(
program
&
p
,
bool
gpu
)
{
p
.
compile
(
get_target
(
gpu
));
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace
MIGRAPHX_INLINE_NS
}
// namespace driver
}
// namespace driver
}
// namespace migraphx
}
// namespace migraphx
src/driver/resnet50.cpp
View file @
11e155c2
...
@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
...
@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu269
;
migraphx
::
op
::
relu
relu269
;
auto
mx269
=
mm
->
add_instruction
(
relu269
,
mx268
);
auto
mx269
=
mm
->
add_instruction
(
relu269
,
mx268
);
migraphx
::
op
::
pooling
pooling270
;
migraphx
::
op
::
pooling
pooling270
;
pooling270
.
mode
=
"max"
;
pooling270
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling270
.
padding
=
{
1
,
1
};
pooling270
.
padding
=
{
1
,
1
};
pooling270
.
stride
=
{
2
,
2
};
pooling270
.
stride
=
{
2
,
2
};
pooling270
.
lengths
=
{
3
,
3
};
pooling270
.
lengths
=
{
3
,
3
};
...
@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
...
@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu438
;
migraphx
::
op
::
relu
relu438
;
auto
mx438
=
mm
->
add_instruction
(
relu438
,
mx437
);
auto
mx438
=
mm
->
add_instruction
(
relu438
,
mx437
);
migraphx
::
op
::
pooling
pooling439
;
migraphx
::
op
::
pooling
pooling439
;
pooling439
.
mode
=
"
average
"
;
pooling439
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling439
.
padding
=
{
0
,
0
};
pooling439
.
padding
=
{
0
,
0
};
pooling439
.
stride
=
{
1
,
1
};
pooling439
.
stride
=
{
1
,
1
};
pooling439
.
lengths
=
{
7
,
7
};
pooling439
.
lengths
=
{
7
,
7
};
...
...
src/eliminate_allocation.cpp
View file @
11e155c2
...
@@ -13,13 +13,13 @@
...
@@ -13,13 +13,13 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_allocation
::
apply
(
module
&
p
)
const
void
eliminate_allocation
::
apply
(
module
&
m
)
const
{
{
assert
(
alignment
>
0
);
assert
(
alignment
>
0
);
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
std
::
size_t
>>
allocs
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
std
::
size_t
>>
allocs
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()
!=
allocation_op
)
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
continue
;
...
@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
...
@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
}
}
if
(
n
>
0
)
if
(
n
>
0
)
{
{
auto
mem
=
p
.
add_parameter
(
"memory"
,
shape
{
shape
::
int8_type
,
{
n
}});
auto
mem
=
m
.
add_parameter
(
"memory"
,
shape
{
shape
::
int8_type
,
{
n
}});
for
(
auto
&&
pp
:
allocs
)
for
(
auto
&&
pp
:
allocs
)
{
{
auto
ins
=
pp
.
first
;
auto
ins
=
pp
.
first
;
auto
s
=
ins
->
get_shape
();
auto
s
=
ins
->
get_shape
();
auto
offset
=
pp
.
second
;
auto
offset
=
pp
.
second
;
p
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
}
}
}
}
...
...
src/eliminate_common_subexpression.cpp
View file @
11e155c2
...
@@ -11,7 +11,7 @@ namespace migraphx {
...
@@ -11,7 +11,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Range
>
template
<
class
Range
>
void
cse_range
(
module
&
p
,
Range
&&
r
)
void
cse_range
(
module
&
m
,
Range
&&
r
)
{
{
std
::
unordered_multimap
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_multimap
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_set
<
instruction_ref
>
processed_ins
;
std
::
unordered_set
<
instruction_ref
>
processed_ins
;
...
@@ -30,19 +30,24 @@ void cse_range(module& p, Range&& r)
...
@@ -30,19 +30,24 @@ void cse_range(module& p, Range&& r)
continue
;
continue
;
if
(
*
eq
!=
*
ins
)
if
(
*
eq
!=
*
ins
)
continue
;
continue
;
p
.
replace_instruction
(
ins
,
eq
);
m
.
replace_instruction
(
ins
,
eq
);
processed_ins
.
emplace
(
ins
);
processed_ins
.
emplace
(
ins
);
auto
outputs
=
eq
->
outputs
();
std
::
vector
<
instruction_ref
>
outputs
;
std
::
copy_if
(
eq
->
outputs
().
begin
(),
eq
->
outputs
().
end
(),
std
::
back_inserter
(
outputs
),
[
&
](
auto
x
)
{
return
m
.
has_instruction
(
x
);
});
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
});
});
cse_range
(
p
,
outputs
);
cse_range
(
m
,
outputs
);
}
}
instructions
.
emplace
(
ins
->
name
(),
ins
);
instructions
.
emplace
(
ins
->
name
(),
ins
);
}
}
}
}
void
eliminate_common_subexpression
::
apply
(
module
&
p
)
const
{
cse_range
(
p
,
iterator_for
(
p
));
}
void
eliminate_common_subexpression
::
apply
(
module
&
m
)
const
{
cse_range
(
m
,
iterator_for
(
m
));
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/eliminate_concat.cpp
View file @
11e155c2
...
@@ -13,9 +13,9 @@
...
@@ -13,9 +13,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_concat
::
apply
(
module
&
p
)
const
void
eliminate_concat
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// Look for the concat operator
// Look for the concat operator
if
(
ins
->
name
()
!=
concat_opt
.
name
())
if
(
ins
->
name
()
!=
concat_opt
.
name
())
...
@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
...
@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
std
::
sort
(
sorted_allocations
.
begin
(),
std
::
sort
(
sorted_allocations
.
begin
(),
sorted_allocations
.
end
(),
sorted_allocations
.
end
(),
[
&
](
instruction_ref
x
,
instruction_ref
y
)
{
[
&
](
instruction_ref
x
,
instruction_ref
y
)
{
return
std
::
distance
(
p
.
begin
(),
x
)
<
std
::
distance
(
p
.
begin
(),
y
);
return
std
::
distance
(
m
.
begin
(),
x
)
<
std
::
distance
(
m
.
begin
(),
y
);
});
});
// Move "super" allocation to the front
// Move "super" allocation to the front
auto
first
=
sorted_allocations
.
front
();
auto
first
=
sorted_allocations
.
front
();
auto
super
=
p
.
move_instruction
(
last
,
first
);
auto
super
=
m
.
move_instruction
(
last
,
first
);
// Replace each allocation with a load
// Replace each allocation with a load
std
::
size_t
offset
=
0
;
std
::
size_t
offset
=
0
;
for
(
auto
alloc
:
allocations
)
for
(
auto
alloc
:
allocations
)
{
{
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
p
.
replace_instruction
(
alloc
,
op
,
{
super
});
m
.
replace_instruction
(
alloc
,
op
,
{
super
});
offset
+=
alloc
->
get_shape
().
bytes
();
offset
+=
alloc
->
get_shape
().
bytes
();
}
}
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
p
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"identity"
),
args
);
m
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"identity"
),
args
);
}
}
}
}
}
}
...
...
src/eliminate_contiguous.cpp
View file @
11e155c2
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp>
#include <utility>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -69,38 +70,52 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -69,38 +70,52 @@ static bool try_compute_shape(instruction_ref ins,
return
try_compute_shape
(
ins
,
inputs
,
mods
);
return
try_compute_shape
(
ins
,
inputs
,
mods
);
}
}
void
eliminate_contiguous
::
apply
(
module
&
p
)
const
void
eliminate_contiguous
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
std
::
vector
<
instruction_ref
>
const_instruction
;
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// return instruction should have inputs with standard shape
// return instruction should have inputs with standard shape
if
(
ins
->
name
()
==
"@return"
)
if
(
ins
->
name
()
==
"@return"
)
continue
;
continue
;
// Make a copy so we can modify it while we iterate
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
auto
mod_args
=
ins
->
module_inputs
();
for
(
auto
arg
:
ins
->
inputs
())
for
(
auto
arg
:
ins
->
inputs
())
{
{
if
(
arg
->
name
()
==
op_name
)
if
(
arg
->
name
()
==
op_name
)
{
{
auto
new_args
=
args
;
auto
prev
=
arg
->
inputs
().
front
();
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
ins
->
module_inputs
()
))
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
{
{
instruction
::
replace_argument
(
ins
,
arg
,
prev
);
instruction
::
replace_argument
(
ins
,
arg
,
prev
);
}
}
else
if
(
prev
->
can_eval
())
else
if
(
prev
->
can_eval
())
{
{
auto
c
=
op
::
contiguous
{};
const_instruction
.
push_back
(
arg
);
auto
r
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
p
.
replace_instruction
(
arg
,
l
);
}
}
}
}
}
}
}
}
// Perform evaluations in parallel
std
::
vector
<
argument
>
literals
(
const_instruction
.
size
());
par_for
(
const_instruction
.
size
(),
1
,
[
&
](
const
auto
i
)
{
auto
c
=
op
::
contiguous
{};
auto
prev
=
const_instruction
[
i
]
->
inputs
().
front
();
literals
[
i
]
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
});
for
(
size_t
i
=
0
;
i
<
const_instruction
.
size
();
i
++
)
{
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
m
.
replace_instruction
(
const_instruction
[
i
],
l
);
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/eliminate_data_type.cpp
View file @
11e155c2
...
@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
void
eliminate_data_type
::
apply
(
module
&
m
)
const
void
eliminate_data_type
::
apply
(
module
&
m
)
const
{
{
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
"convert"
,
"convert"
,
"get_tuple_elem"
,
"if"
,
"loop"
,
"roialign"
};
"get_tuple_elem"
,
"if"
,
"loop"
,
"roialign"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_none"
};
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()[
0
]
==
'@'
)
if
(
ins
->
name
()[
0
]
==
'@'
)
...
...
src/eliminate_identity.cpp
View file @
11e155c2
...
@@ -8,21 +8,21 @@
...
@@ -8,21 +8,21 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_identity
::
apply
(
module
&
p
)
const
void
eliminate_identity
::
apply
(
module
&
m
)
const
{
{
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// Skip the first instruction, since we always process the previous
// Skip the first instruction, since we always process the previous
// instruction
// instruction
if
(
ins
==
p
.
begin
())
if
(
ins
==
m
.
begin
())
continue
;
continue
;
const
auto
i
=
std
::
prev
(
ins
);
const
auto
i
=
std
::
prev
(
ins
);
if
(
i
->
name
()
==
"identity"
)
if
(
i
->
name
()
==
"identity"
)
{
{
p
.
replace_instruction
(
i
,
i
->
inputs
().
front
());
m
.
replace_instruction
(
i
,
i
->
inputs
().
front
());
p
.
move_instruction
(
i
,
p
.
end
());
m
.
move_instruction
(
i
,
m
.
end
());
}
}
if
(
ins
==
last
)
if
(
ins
==
last
)
{
{
...
@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
...
@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
const
instruction_ref
&
identity_input
=
ins
->
inputs
().
front
();
const
instruction_ref
&
identity_input
=
ins
->
inputs
().
front
();
if
(
identity_input
->
outputs
().
size
()
==
1
)
if
(
identity_input
->
outputs
().
size
()
==
1
)
{
{
p
.
move_instruction
(
identity_input
,
i
);
m
.
move_instruction
(
identity_input
,
i
);
// since this is the last instruction, removing it only
// since this is the last instruction, removing it only
// requires changing "last" and calling remove below
// requires changing "last" and calling remove below
last
=
std
::
prev
(
last
);
last
=
std
::
prev
(
last
);
...
@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
...
@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
break
;
break
;
}
}
}
}
p
.
remove_instructions
(
std
::
next
(
last
),
p
.
end
());
m
.
remove_instructions
(
std
::
next
(
last
),
m
.
end
());
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/eliminate_pad.cpp
View file @
11e155c2
...
@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
...
@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static
void
update_pooling
(
const
instruction_ref
&
input
,
const
instruction_ref
&
ins
,
module
&
m
)
static
void
update_pooling
(
const
instruction_ref
&
input
,
const
instruction_ref
&
ins
,
module
&
m
)
{
{
auto
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
auto
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
if
(
op
.
mode
==
"
average
"
)
if
(
op
.
mode
==
op
::
pooling_mode
::
average
)
{
{
return
;
return
;
}
}
...
...
src/include/migraphx/adjust_allocation.hpp
View file @
11e155c2
...
@@ -13,7 +13,7 @@ struct adjust_allocation
...
@@ -13,7 +13,7 @@ struct adjust_allocation
{
{
allocation_model
model
;
allocation_model
model
;
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/allocation_model.hpp
View file @
11e155c2
...
@@ -32,18 +32,22 @@ struct allocation_model
...
@@ -32,18 +32,22 @@ struct allocation_model
#else
#else
/*
#ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
*
// Type-erased interface for:
* struct allocation_model
struct
allocation_model
* {
{
* std::string name() const;
//
* std::string copy() const;
std
::
string
name
()
const
;
* operation allocate(const shape& s) const;
//
* operation preallocate(const shape& s,std::string id) const;
std
::
string
copy
()
const
;
* };
//
*
operation
allocate
(
const
shape
&
s
)
const
;
*/
//
operation
preallocate
(
const
shape
&
s
,
std
::
string
id
)
const
;
};
#else
struct
allocation_model
struct
allocation_model
{
{
...
@@ -260,6 +264,7 @@ inline const ValueType& any_cast(const allocation_model& x)
...
@@ -260,6 +264,7 @@ inline const ValueType& any_cast(const allocation_model& x)
throw
std
::
bad_cast
();
throw
std
::
bad_cast
();
return
*
y
;
return
*
y
;
}
}
#endif
#endif
#endif
...
...
src/include/migraphx/analyze_streams.hpp
View file @
11e155c2
...
@@ -16,7 +16,7 @@ struct stream_race
...
@@ -16,7 +16,7 @@ struct stream_race
instruction_ref
before
;
instruction_ref
before
;
};
};
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
p
,
const
stream_model
&
m
);
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
m
,
const
stream_model
&
strm
m
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/any_ptr.hpp
0 → 100644
View file @
11e155c2
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#include <migraphx/config.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/type_name.hpp>
#include <cassert>
#include <string_view>
#include <typeindex>
#include <type_traits>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
any_ptr
{
any_ptr
()
=
default
;
template
<
class
T
>
any_ptr
(
T
*
p
)
:
ptr
(
p
),
ti
(
typeid
(
T
*
)),
name
(
get_name
<
T
*>
())
{
}
any_ptr
(
void
*
p
,
std
::
string_view
pname
)
:
ptr
(
p
),
name
(
pname
)
{}
void
*
get
(
std
::
string_view
n
)
const
{
if
(
name
!=
n
)
MIGRAPHX_THROW
(
"any_ptr: type mismatch: "
+
std
::
string
{
name
}
+
" != "
+
std
::
string
{
n
});
return
ptr
;
}
template
<
class
T
>
T
get
()
const
{
static_assert
(
std
::
is_pointer
<
T
>
{},
"Must be a pointer"
);
assert
(
ptr
!=
nullptr
);
if
(
ti
and
std
::
type_index
{
typeid
(
T
)}
!=
*
ti
)
MIGRAPHX_THROW
(
"any_ptr: type mismatch: "
+
std
::
string
{
name
}
+
" != "
+
get_name
<
T
>
());
else
if
(
name
!=
get_name
<
T
>
())
MIGRAPHX_THROW
(
"any_ptr: type mismatch: "
+
std
::
string
{
name
}
+
" != "
+
get_name
<
T
>
());
return
reinterpret_cast
<
T
>
(
ptr
);
}
void
*
unsafe_get
()
const
{
return
ptr
;
}
private:
void
*
ptr
=
nullptr
;
optional
<
std
::
type_index
>
ti
=
nullopt
;
std
::
string_view
name
=
""
;
template
<
class
T
>
static
const
std
::
string
&
get_name
()
{
return
get_type_name
<
std
::
remove_cv_t
<
std
::
remove_pointer_t
<
T
>>>
();
}
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
src/include/migraphx/auto_contiguous.hpp
View file @
11e155c2
...
@@ -13,7 +13,7 @@ struct module;
...
@@ -13,7 +13,7 @@ struct module;
struct
auto_contiguous
struct
auto_contiguous
{
{
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
Prev
1
2
3
4
5
6
7
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment