Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
11e155c2
Commit
11e155c2
authored
Jun 13, 2022
by
Paul
Browse files
Merge
parents
8a9c5bce
aa7ff911
Changes
397
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
188 additions
and
96 deletions
+188
-96
src/cpp_generator.cpp
src/cpp_generator.cpp
+9
-1
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+6
-22
src/driver/alexnet.cpp
src/driver/alexnet.cpp
+3
-3
src/driver/inceptionv3.cpp
src/driver/inceptionv3.cpp
+14
-14
src/driver/main.cpp
src/driver/main.cpp
+8
-0
src/driver/marker_roctx.cpp
src/driver/marker_roctx.cpp
+1
-1
src/driver/perf.cpp
src/driver/perf.cpp
+1
-1
src/driver/resnet50.cpp
src/driver/resnet50.cpp
+2
-2
src/eliminate_allocation.cpp
src/eliminate_allocation.cpp
+4
-4
src/eliminate_common_subexpression.cpp
src/eliminate_common_subexpression.cpp
+10
-5
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+6
-6
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+26
-11
src/eliminate_data_type.cpp
src/eliminate_data_type.cpp
+8
-2
src/eliminate_identity.cpp
src/eliminate_identity.cpp
+8
-8
src/eliminate_pad.cpp
src/eliminate_pad.cpp
+1
-1
src/include/migraphx/adjust_allocation.hpp
src/include/migraphx/adjust_allocation.hpp
+1
-1
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+17
-12
src/include/migraphx/analyze_streams.hpp
src/include/migraphx/analyze_streams.hpp
+1
-1
src/include/migraphx/any_ptr.hpp
src/include/migraphx/any_ptr.hpp
+61
-0
src/include/migraphx/auto_contiguous.hpp
src/include/migraphx/auto_contiguous.hpp
+1
-1
No files found.
src/cpp_generator.cpp
View file @
11e155c2
...
@@ -88,6 +88,7 @@ struct cpp_generator_impl
...
@@ -88,6 +88,7 @@ struct cpp_generator_impl
std
::
stringstream
fs
{};
std
::
stringstream
fs
{};
std
::
size_t
function_count
=
0
;
std
::
size_t
function_count
=
0
;
std
::
function
<
std
::
string
(
std
::
string
)
>
fmap
=
nullptr
;
std
::
function
<
std
::
string
(
std
::
string
)
>
fmap
=
nullptr
;
std
::
function
<
std
::
string
(
shape
)
>
fresult
=
nullptr
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
point_op_map
=
{};
std
::
unordered_map
<
std
::
string
,
std
::
string
>
point_op_map
=
{};
};
};
cpp_generator
::
cpp_generator
()
:
impl
(
std
::
make_unique
<
cpp_generator_impl
>
())
{}
cpp_generator
::
cpp_generator
()
:
impl
(
std
::
make_unique
<
cpp_generator_impl
>
())
{}
...
@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
...
@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
void
cpp_generator
::
fmap
(
const
std
::
function
<
std
::
string
(
std
::
string
)
>&
f
)
{
impl
->
fmap
=
f
;
}
void
cpp_generator
::
fmap
(
const
std
::
function
<
std
::
string
(
std
::
string
)
>&
f
)
{
impl
->
fmap
=
f
;
}
void
cpp_generator
::
fresult
(
const
std
::
function
<
std
::
string
(
shape
)
>&
f
)
{
impl
->
fresult
=
f
;
}
void
cpp_generator
::
add_point_op
(
const
std
::
string
&
op_name
,
const
std
::
string
&
code
)
void
cpp_generator
::
add_point_op
(
const
std
::
string
&
op_name
,
const
std
::
string
&
code
)
{
{
impl
->
point_op_map
[
op_name
]
=
code
;
impl
->
point_op_map
[
op_name
]
=
code
;
...
@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
...
@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins
->
inputs
().
end
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
args
),
std
::
back_inserter
(
args
),
[
&
](
auto
i
)
{
return
names
.
at
(
i
);
});
[
&
](
auto
i
)
{
return
names
.
at
(
i
);
});
return
this
->
generate_point_op
(
ins
->
get_operator
(),
args
);
auto
s
=
this
->
generate_point_op
(
ins
->
get_operator
(),
args
);
if
(
impl
->
fresult
)
return
impl
->
fresult
(
ins
->
get_shape
())
+
'('
+
s
+
')'
;
else
return
s
;
});
});
return
f
;
return
f
;
}
}
...
...
src/dead_code_elimination.cpp
View file @
11e155c2
...
@@ -9,26 +9,6 @@
...
@@ -9,26 +9,6 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Range
,
class
Iterator
>
std
::
ptrdiff_t
bidistance
(
const
Range
&
r
,
Iterator
start
,
Iterator
last
)
{
auto
start_forward
=
start
;
auto
start_backwards
=
start
;
std
::
size_t
n
=
0
;
while
(
start_forward
!=
last
and
start_backwards
!=
last
)
{
n
++
;
if
(
start_forward
!=
r
.
end
())
start_forward
++
;
if
(
start_backwards
!=
r
.
begin
())
start_backwards
--
;
}
if
(
start_forward
==
last
)
return
n
;
else
return
-
n
;
}
void
dead_code_elimination
::
apply
(
program
&
p
)
const
{
p
.
remove_unused_modules
();
}
void
dead_code_elimination
::
apply
(
program
&
p
)
const
{
p
.
remove_unused_modules
();
}
void
dead_code_elimination
::
apply
(
module
&
m
)
const
void
dead_code_elimination
::
apply
(
module
&
m
)
const
...
@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
...
@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
if
(
i
->
get_shape
().
elements
()
==
0
and
i
->
name
().
front
()
!=
'@'
and
if
(
i
->
get_shape
().
elements
()
==
0
and
i
->
name
().
front
()
!=
'@'
and
i
->
name
()
!=
"undefined"
and
i
->
name
()
!=
"identity"
)
i
->
name
()
!=
"undefined"
and
i
->
name
()
!=
"identity"
)
continue
;
continue
;
assert
(
bidistance
(
m
,
i
,
last
)
>
0
);
assert
(
std
::
distance
(
m
.
begin
(),
i
)
<=
std
::
distance
(
m
.
begin
(),
last
));
std
::
unordered_set
<
instruction_ref
>
visited
;
fix
([
&
](
auto
self
,
auto
leaf
)
{
fix
([
&
](
auto
self
,
auto
leaf
)
{
if
(
not
m
.
has_instruction
(
leaf
))
if
(
not
m
.
has_instruction
(
leaf
))
return
;
return
;
if
(
leaf
->
outputs
().
empty
())
if
(
leaf
->
outputs
().
empty
())
{
{
// Dont visit inputs twice
if
(
not
visited
.
insert
(
leaf
).
second
)
return
;
std
::
unordered_set
<
instruction_ref
>
args
(
leaf
->
inputs
().
begin
(),
std
::
unordered_set
<
instruction_ref
>
args
(
leaf
->
inputs
().
begin
(),
leaf
->
inputs
().
end
());
leaf
->
inputs
().
end
());
leaf
->
clear_arguments
();
leaf
->
clear_arguments
();
assert
(
bi
distance
(
m
,
last
,
leaf
)
<
0
);
assert
(
std
::
distance
(
m
.
begin
(),
leaf
)
<
std
::
distance
(
m
.
begin
(),
last
)
);
assert
(
leaf
!=
ins
);
assert
(
leaf
!=
ins
);
if
(
leaf
->
name
()
!=
"@param"
)
if
(
leaf
->
name
()
!=
"@param"
)
m
.
move_instruction
(
leaf
,
m
.
end
());
m
.
move_instruction
(
leaf
,
m
.
end
());
...
...
src/driver/alexnet.cpp
View file @
11e155c2
...
@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu19
;
migraphx
::
op
::
relu
relu19
;
auto
mx19
=
mm
->
add_instruction
(
relu19
,
mx18
);
auto
mx19
=
mm
->
add_instruction
(
relu19
,
mx18
);
migraphx
::
op
::
pooling
pooling20
;
migraphx
::
op
::
pooling
pooling20
;
pooling20
.
mode
=
"max"
;
pooling20
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling20
.
padding
=
{
0
,
0
};
pooling20
.
padding
=
{
0
,
0
};
pooling20
.
stride
=
{
2
,
2
};
pooling20
.
stride
=
{
2
,
2
};
pooling20
.
lengths
=
{
3
,
3
};
pooling20
.
lengths
=
{
3
,
3
};
...
@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu24
;
migraphx
::
op
::
relu
relu24
;
auto
mx24
=
mm
->
add_instruction
(
relu24
,
mx23
);
auto
mx24
=
mm
->
add_instruction
(
relu24
,
mx23
);
migraphx
::
op
::
pooling
pooling25
;
migraphx
::
op
::
pooling
pooling25
;
pooling25
.
mode
=
"max"
;
pooling25
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling25
.
padding
=
{
0
,
0
};
pooling25
.
padding
=
{
0
,
0
};
pooling25
.
stride
=
{
2
,
2
};
pooling25
.
stride
=
{
2
,
2
};
pooling25
.
lengths
=
{
3
,
3
};
pooling25
.
lengths
=
{
3
,
3
};
...
@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu37
;
migraphx
::
op
::
relu
relu37
;
auto
mx37
=
mm
->
add_instruction
(
relu37
,
mx36
);
auto
mx37
=
mm
->
add_instruction
(
relu37
,
mx36
);
migraphx
::
op
::
pooling
pooling38
;
migraphx
::
op
::
pooling
pooling38
;
pooling38
.
mode
=
"max"
;
pooling38
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling38
.
padding
=
{
0
,
0
};
pooling38
.
padding
=
{
0
,
0
};
pooling38
.
stride
=
{
2
,
2
};
pooling38
.
stride
=
{
2
,
2
};
pooling38
.
lengths
=
{
3
,
3
};
pooling38
.
lengths
=
{
3
,
3
};
...
...
src/driver/inceptionv3.cpp
View file @
11e155c2
...
@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu492
;
migraphx
::
op
::
relu
relu492
;
auto
mx492
=
mm
->
add_instruction
(
relu492
,
mx491
);
auto
mx492
=
mm
->
add_instruction
(
relu492
,
mx491
);
migraphx
::
op
::
pooling
pooling493
;
migraphx
::
op
::
pooling
pooling493
;
pooling493
.
mode
=
"max"
;
pooling493
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling493
.
padding
=
{
0
,
0
};
pooling493
.
padding
=
{
0
,
0
};
pooling493
.
stride
=
{
2
,
2
};
pooling493
.
stride
=
{
2
,
2
};
pooling493
.
lengths
=
{
3
,
3
};
pooling493
.
lengths
=
{
3
,
3
};
...
@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu499
;
migraphx
::
op
::
relu
relu499
;
auto
mx499
=
mm
->
add_instruction
(
relu499
,
mx498
);
auto
mx499
=
mm
->
add_instruction
(
relu499
,
mx498
);
migraphx
::
op
::
pooling
pooling500
;
migraphx
::
op
::
pooling
pooling500
;
pooling500
.
mode
=
"max"
;
pooling500
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling500
.
padding
=
{
0
,
0
};
pooling500
.
padding
=
{
0
,
0
};
pooling500
.
stride
=
{
2
,
2
};
pooling500
.
stride
=
{
2
,
2
};
pooling500
.
lengths
=
{
3
,
3
};
pooling500
.
lengths
=
{
3
,
3
};
...
@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu518
;
migraphx
::
op
::
relu
relu518
;
auto
mx518
=
mm
->
add_instruction
(
relu518
,
mx517
);
auto
mx518
=
mm
->
add_instruction
(
relu518
,
mx517
);
migraphx
::
op
::
pooling
pooling519
;
migraphx
::
op
::
pooling
pooling519
;
pooling519
.
mode
=
"
average
"
;
pooling519
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling519
.
padding
=
{
1
,
1
};
pooling519
.
padding
=
{
1
,
1
};
pooling519
.
stride
=
{
1
,
1
};
pooling519
.
stride
=
{
1
,
1
};
pooling519
.
lengths
=
{
3
,
3
};
pooling519
.
lengths
=
{
3
,
3
};
...
@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu541
;
migraphx
::
op
::
relu
relu541
;
auto
mx541
=
mm
->
add_instruction
(
relu541
,
mx540
);
auto
mx541
=
mm
->
add_instruction
(
relu541
,
mx540
);
migraphx
::
op
::
pooling
pooling542
;
migraphx
::
op
::
pooling
pooling542
;
pooling542
.
mode
=
"
average
"
;
pooling542
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling542
.
padding
=
{
1
,
1
};
pooling542
.
padding
=
{
1
,
1
};
pooling542
.
stride
=
{
1
,
1
};
pooling542
.
stride
=
{
1
,
1
};
pooling542
.
lengths
=
{
3
,
3
};
pooling542
.
lengths
=
{
3
,
3
};
...
@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu564
;
migraphx
::
op
::
relu
relu564
;
auto
mx564
=
mm
->
add_instruction
(
relu564
,
mx563
);
auto
mx564
=
mm
->
add_instruction
(
relu564
,
mx563
);
migraphx
::
op
::
pooling
pooling565
;
migraphx
::
op
::
pooling
pooling565
;
pooling565
.
mode
=
"
average
"
;
pooling565
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling565
.
padding
=
{
1
,
1
};
pooling565
.
padding
=
{
1
,
1
};
pooling565
.
stride
=
{
1
,
1
};
pooling565
.
stride
=
{
1
,
1
};
pooling565
.
lengths
=
{
3
,
3
};
pooling565
.
lengths
=
{
3
,
3
};
...
@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu581
;
migraphx
::
op
::
relu
relu581
;
auto
mx581
=
mm
->
add_instruction
(
relu581
,
mx580
);
auto
mx581
=
mm
->
add_instruction
(
relu581
,
mx580
);
migraphx
::
op
::
pooling
pooling582
;
migraphx
::
op
::
pooling
pooling582
;
pooling582
.
mode
=
"max"
;
pooling582
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling582
.
padding
=
{
0
,
0
};
pooling582
.
padding
=
{
0
,
0
};
pooling582
.
stride
=
{
2
,
2
};
pooling582
.
stride
=
{
2
,
2
};
pooling582
.
lengths
=
{
3
,
3
};
pooling582
.
lengths
=
{
3
,
3
};
...
@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu610
;
migraphx
::
op
::
relu
relu610
;
auto
mx610
=
mm
->
add_instruction
(
relu610
,
mx609
);
auto
mx610
=
mm
->
add_instruction
(
relu610
,
mx609
);
migraphx
::
op
::
pooling
pooling611
;
migraphx
::
op
::
pooling
pooling611
;
pooling611
.
mode
=
"
average
"
;
pooling611
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling611
.
padding
=
{
1
,
1
};
pooling611
.
padding
=
{
1
,
1
};
pooling611
.
stride
=
{
1
,
1
};
pooling611
.
stride
=
{
1
,
1
};
pooling611
.
lengths
=
{
3
,
3
};
pooling611
.
lengths
=
{
3
,
3
};
...
@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu642
;
migraphx
::
op
::
relu
relu642
;
auto
mx642
=
mm
->
add_instruction
(
relu642
,
mx641
);
auto
mx642
=
mm
->
add_instruction
(
relu642
,
mx641
);
migraphx
::
op
::
pooling
pooling643
;
migraphx
::
op
::
pooling
pooling643
;
pooling643
.
mode
=
"
average
"
;
pooling643
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling643
.
padding
=
{
1
,
1
};
pooling643
.
padding
=
{
1
,
1
};
pooling643
.
stride
=
{
1
,
1
};
pooling643
.
stride
=
{
1
,
1
};
pooling643
.
lengths
=
{
3
,
3
};
pooling643
.
lengths
=
{
3
,
3
};
...
@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu674
;
migraphx
::
op
::
relu
relu674
;
auto
mx674
=
mm
->
add_instruction
(
relu674
,
mx673
);
auto
mx674
=
mm
->
add_instruction
(
relu674
,
mx673
);
migraphx
::
op
::
pooling
pooling675
;
migraphx
::
op
::
pooling
pooling675
;
pooling675
.
mode
=
"
average
"
;
pooling675
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling675
.
padding
=
{
1
,
1
};
pooling675
.
padding
=
{
1
,
1
};
pooling675
.
stride
=
{
1
,
1
};
pooling675
.
stride
=
{
1
,
1
};
pooling675
.
lengths
=
{
3
,
3
};
pooling675
.
lengths
=
{
3
,
3
};
...
@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu706
;
migraphx
::
op
::
relu
relu706
;
auto
mx706
=
mm
->
add_instruction
(
relu706
,
mx705
);
auto
mx706
=
mm
->
add_instruction
(
relu706
,
mx705
);
migraphx
::
op
::
pooling
pooling707
;
migraphx
::
op
::
pooling
pooling707
;
pooling707
.
mode
=
"
average
"
;
pooling707
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling707
.
padding
=
{
1
,
1
};
pooling707
.
padding
=
{
1
,
1
};
pooling707
.
stride
=
{
1
,
1
};
pooling707
.
stride
=
{
1
,
1
};
pooling707
.
lengths
=
{
3
,
3
};
pooling707
.
lengths
=
{
3
,
3
};
...
@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu729
;
migraphx
::
op
::
relu
relu729
;
auto
mx729
=
mm
->
add_instruction
(
relu729
,
mx728
);
auto
mx729
=
mm
->
add_instruction
(
relu729
,
mx728
);
migraphx
::
op
::
pooling
pooling730
;
migraphx
::
op
::
pooling
pooling730
;
pooling730
.
mode
=
"max"
;
pooling730
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling730
.
padding
=
{
0
,
0
};
pooling730
.
padding
=
{
0
,
0
};
pooling730
.
stride
=
{
2
,
2
};
pooling730
.
stride
=
{
2
,
2
};
pooling730
.
lengths
=
{
3
,
3
};
pooling730
.
lengths
=
{
3
,
3
};
...
@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat757
.
axis
=
1
;
concat757
.
axis
=
1
;
auto
mx757
=
mm
->
add_instruction
(
concat757
,
mx753
,
mx756
);
auto
mx757
=
mm
->
add_instruction
(
concat757
,
mx753
,
mx756
);
migraphx
::
op
::
pooling
pooling758
;
migraphx
::
op
::
pooling
pooling758
;
pooling758
.
mode
=
"
average
"
;
pooling758
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling758
.
padding
=
{
1
,
1
};
pooling758
.
padding
=
{
1
,
1
};
pooling758
.
stride
=
{
1
,
1
};
pooling758
.
stride
=
{
1
,
1
};
pooling758
.
lengths
=
{
3
,
3
};
pooling758
.
lengths
=
{
3
,
3
};
...
@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat788
.
axis
=
1
;
concat788
.
axis
=
1
;
auto
mx788
=
mm
->
add_instruction
(
concat788
,
mx784
,
mx787
);
auto
mx788
=
mm
->
add_instruction
(
concat788
,
mx784
,
mx787
);
migraphx
::
op
::
pooling
pooling789
;
migraphx
::
op
::
pooling
pooling789
;
pooling789
.
mode
=
"
average
"
;
pooling789
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling789
.
padding
=
{
1
,
1
};
pooling789
.
padding
=
{
1
,
1
};
pooling789
.
stride
=
{
1
,
1
};
pooling789
.
stride
=
{
1
,
1
};
pooling789
.
lengths
=
{
3
,
3
};
pooling789
.
lengths
=
{
3
,
3
};
...
@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat793
.
axis
=
1
;
concat793
.
axis
=
1
;
auto
mx793
=
mm
->
add_instruction
(
concat793
,
mx765
,
mx775
,
mx788
,
mx792
);
auto
mx793
=
mm
->
add_instruction
(
concat793
,
mx765
,
mx775
,
mx788
,
mx792
);
migraphx
::
op
::
pooling
pooling794
;
migraphx
::
op
::
pooling
pooling794
;
pooling794
.
mode
=
"
average
"
;
pooling794
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling794
.
padding
=
{
0
,
0
};
pooling794
.
padding
=
{
0
,
0
};
pooling794
.
stride
=
{
8
,
8
};
pooling794
.
stride
=
{
8
,
8
};
pooling794
.
lengths
=
{
8
,
8
};
pooling794
.
lengths
=
{
8
,
8
};
...
...
src/driver/main.cpp
View file @
11e155c2
...
@@ -508,8 +508,10 @@ struct roctx : command<roctx>
...
@@ -508,8 +508,10 @@ struct roctx : command<roctx>
struct
op
:
command
<
op
>
struct
op
:
command
<
op
>
{
{
bool
show_ops
=
false
;
bool
show_ops
=
false
;
std
::
string
op_name
{};
void
parse
(
argument_parser
&
ap
)
void
parse
(
argument_parser
&
ap
)
{
{
ap
(
op_name
,
{},
ap
.
metavar
(
"<MIGraphX operator name>"
));
ap
(
show_ops
,
ap
(
show_ops
,
{
"--list"
,
"-l"
},
{
"--list"
,
"-l"
},
ap
.
help
(
"List all the operators of MIGraphX"
),
ap
.
help
(
"List all the operators of MIGraphX"
),
...
@@ -522,6 +524,12 @@ struct op : command<op>
...
@@ -522,6 +524,12 @@ struct op : command<op>
for
(
const
auto
&
name
:
get_operators
())
for
(
const
auto
&
name
:
get_operators
())
std
::
cout
<<
name
<<
std
::
endl
;
std
::
cout
<<
name
<<
std
::
endl
;
}
}
else
{
auto
op
=
load_op
(
op_name
);
std
::
cout
<<
op_name
<<
": "
<<
std
::
endl
;
std
::
cout
<<
to_pretty_json_string
(
op
.
to_value
())
<<
std
::
endl
;
}
}
}
};
};
...
...
src/driver/marker_roctx.cpp
View file @
11e155c2
...
@@ -17,7 +17,7 @@ class marker_roctx
...
@@ -17,7 +17,7 @@ class marker_roctx
std
::
function
<
int
(
const
char
*
)
>
sym_roctx_range_push
;
std
::
function
<
int
(
const
char
*
)
>
sym_roctx_range_push
;
std
::
function
<
int
()
>
sym_roctx_range_pop
;
std
::
function
<
int
()
>
sym_roctx_range_pop
;
uint64_t
range_id
;
uint64_t
range_id
=
0
;
public:
public:
marker_roctx
()
marker_roctx
()
...
...
src/driver/perf.cpp
View file @
11e155c2
...
@@ -87,6 +87,6 @@ target get_target(bool gpu)
...
@@ -87,6 +87,6 @@ target get_target(bool gpu)
void
compile_program
(
program
&
p
,
bool
gpu
)
{
p
.
compile
(
get_target
(
gpu
));
}
void
compile_program
(
program
&
p
,
bool
gpu
)
{
p
.
compile
(
get_target
(
gpu
));
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace
MIGRAPHX_INLINE_NS
}
// namespace driver
}
// namespace driver
}
// namespace migraphx
}
// namespace migraphx
src/driver/resnet50.cpp
View file @
11e155c2
...
@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
...
@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu269
;
migraphx
::
op
::
relu
relu269
;
auto
mx269
=
mm
->
add_instruction
(
relu269
,
mx268
);
auto
mx269
=
mm
->
add_instruction
(
relu269
,
mx268
);
migraphx
::
op
::
pooling
pooling270
;
migraphx
::
op
::
pooling
pooling270
;
pooling270
.
mode
=
"max"
;
pooling270
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling270
.
padding
=
{
1
,
1
};
pooling270
.
padding
=
{
1
,
1
};
pooling270
.
stride
=
{
2
,
2
};
pooling270
.
stride
=
{
2
,
2
};
pooling270
.
lengths
=
{
3
,
3
};
pooling270
.
lengths
=
{
3
,
3
};
...
@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
...
@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu438
;
migraphx
::
op
::
relu
relu438
;
auto
mx438
=
mm
->
add_instruction
(
relu438
,
mx437
);
auto
mx438
=
mm
->
add_instruction
(
relu438
,
mx437
);
migraphx
::
op
::
pooling
pooling439
;
migraphx
::
op
::
pooling
pooling439
;
pooling439
.
mode
=
"
average
"
;
pooling439
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling439
.
padding
=
{
0
,
0
};
pooling439
.
padding
=
{
0
,
0
};
pooling439
.
stride
=
{
1
,
1
};
pooling439
.
stride
=
{
1
,
1
};
pooling439
.
lengths
=
{
7
,
7
};
pooling439
.
lengths
=
{
7
,
7
};
...
...
src/eliminate_allocation.cpp
View file @
11e155c2
...
@@ -13,13 +13,13 @@
...
@@ -13,13 +13,13 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_allocation
::
apply
(
module
&
p
)
const
void
eliminate_allocation
::
apply
(
module
&
m
)
const
{
{
assert
(
alignment
>
0
);
assert
(
alignment
>
0
);
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
std
::
size_t
>>
allocs
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
std
::
size_t
>>
allocs
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()
!=
allocation_op
)
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
continue
;
...
@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
...
@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
}
}
if
(
n
>
0
)
if
(
n
>
0
)
{
{
auto
mem
=
p
.
add_parameter
(
"memory"
,
shape
{
shape
::
int8_type
,
{
n
}});
auto
mem
=
m
.
add_parameter
(
"memory"
,
shape
{
shape
::
int8_type
,
{
n
}});
for
(
auto
&&
pp
:
allocs
)
for
(
auto
&&
pp
:
allocs
)
{
{
auto
ins
=
pp
.
first
;
auto
ins
=
pp
.
first
;
auto
s
=
ins
->
get_shape
();
auto
s
=
ins
->
get_shape
();
auto
offset
=
pp
.
second
;
auto
offset
=
pp
.
second
;
p
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
}
}
}
}
...
...
src/eliminate_common_subexpression.cpp
View file @
11e155c2
...
@@ -11,7 +11,7 @@ namespace migraphx {
...
@@ -11,7 +11,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Range
>
template
<
class
Range
>
void
cse_range
(
module
&
p
,
Range
&&
r
)
void
cse_range
(
module
&
m
,
Range
&&
r
)
{
{
std
::
unordered_multimap
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_multimap
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_set
<
instruction_ref
>
processed_ins
;
std
::
unordered_set
<
instruction_ref
>
processed_ins
;
...
@@ -30,19 +30,24 @@ void cse_range(module& p, Range&& r)
...
@@ -30,19 +30,24 @@ void cse_range(module& p, Range&& r)
continue
;
continue
;
if
(
*
eq
!=
*
ins
)
if
(
*
eq
!=
*
ins
)
continue
;
continue
;
p
.
replace_instruction
(
ins
,
eq
);
m
.
replace_instruction
(
ins
,
eq
);
processed_ins
.
emplace
(
ins
);
processed_ins
.
emplace
(
ins
);
auto
outputs
=
eq
->
outputs
();
std
::
vector
<
instruction_ref
>
outputs
;
std
::
copy_if
(
eq
->
outputs
().
begin
(),
eq
->
outputs
().
end
(),
std
::
back_inserter
(
outputs
),
[
&
](
auto
x
)
{
return
m
.
has_instruction
(
x
);
});
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
});
});
cse_range
(
p
,
outputs
);
cse_range
(
m
,
outputs
);
}
}
instructions
.
emplace
(
ins
->
name
(),
ins
);
instructions
.
emplace
(
ins
->
name
(),
ins
);
}
}
}
}
void
eliminate_common_subexpression
::
apply
(
module
&
p
)
const
{
cse_range
(
p
,
iterator_for
(
p
));
}
void
eliminate_common_subexpression
::
apply
(
module
&
m
)
const
{
cse_range
(
m
,
iterator_for
(
m
));
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/eliminate_concat.cpp
View file @
11e155c2
...
@@ -13,9 +13,9 @@
...
@@ -13,9 +13,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_concat
::
apply
(
module
&
p
)
const
void
eliminate_concat
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// Look for the concat operator
// Look for the concat operator
if
(
ins
->
name
()
!=
concat_opt
.
name
())
if
(
ins
->
name
()
!=
concat_opt
.
name
())
...
@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
...
@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
std
::
sort
(
sorted_allocations
.
begin
(),
std
::
sort
(
sorted_allocations
.
begin
(),
sorted_allocations
.
end
(),
sorted_allocations
.
end
(),
[
&
](
instruction_ref
x
,
instruction_ref
y
)
{
[
&
](
instruction_ref
x
,
instruction_ref
y
)
{
return
std
::
distance
(
p
.
begin
(),
x
)
<
std
::
distance
(
p
.
begin
(),
y
);
return
std
::
distance
(
m
.
begin
(),
x
)
<
std
::
distance
(
m
.
begin
(),
y
);
});
});
// Move "super" allocation to the front
// Move "super" allocation to the front
auto
first
=
sorted_allocations
.
front
();
auto
first
=
sorted_allocations
.
front
();
auto
super
=
p
.
move_instruction
(
last
,
first
);
auto
super
=
m
.
move_instruction
(
last
,
first
);
// Replace each allocation with a load
// Replace each allocation with a load
std
::
size_t
offset
=
0
;
std
::
size_t
offset
=
0
;
for
(
auto
alloc
:
allocations
)
for
(
auto
alloc
:
allocations
)
{
{
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
p
.
replace_instruction
(
alloc
,
op
,
{
super
});
m
.
replace_instruction
(
alloc
,
op
,
{
super
});
offset
+=
alloc
->
get_shape
().
bytes
();
offset
+=
alloc
->
get_shape
().
bytes
();
}
}
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
p
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"identity"
),
args
);
m
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"identity"
),
args
);
}
}
}
}
}
}
...
...
src/eliminate_contiguous.cpp
View file @
11e155c2
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp>
#include <utility>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -69,38 +70,52 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -69,38 +70,52 @@ static bool try_compute_shape(instruction_ref ins,
return
try_compute_shape
(
ins
,
inputs
,
mods
);
return
try_compute_shape
(
ins
,
inputs
,
mods
);
}
}
void
eliminate_contiguous
::
apply
(
module
&
p
)
const
void
eliminate_contiguous
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
std
::
vector
<
instruction_ref
>
const_instruction
;
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// return instruction should have inputs with standard shape
// return instruction should have inputs with standard shape
if
(
ins
->
name
()
==
"@return"
)
if
(
ins
->
name
()
==
"@return"
)
continue
;
continue
;
// Make a copy so we can modify it while we iterate
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
auto
mod_args
=
ins
->
module_inputs
();
for
(
auto
arg
:
ins
->
inputs
())
for
(
auto
arg
:
ins
->
inputs
())
{
{
if
(
arg
->
name
()
==
op_name
)
if
(
arg
->
name
()
==
op_name
)
{
{
auto
new_args
=
args
;
auto
prev
=
arg
->
inputs
().
front
();
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
ins
->
module_inputs
()
))
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
{
{
instruction
::
replace_argument
(
ins
,
arg
,
prev
);
instruction
::
replace_argument
(
ins
,
arg
,
prev
);
}
}
else
if
(
prev
->
can_eval
())
else
if
(
prev
->
can_eval
())
{
{
auto
c
=
op
::
contiguous
{};
const_instruction
.
push_back
(
arg
);
auto
r
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
p
.
replace_instruction
(
arg
,
l
);
}
}
}
}
}
}
}
}
// Perform evaluations in parallel
std
::
vector
<
argument
>
literals
(
const_instruction
.
size
());
par_for
(
const_instruction
.
size
(),
1
,
[
&
](
const
auto
i
)
{
auto
c
=
op
::
contiguous
{};
auto
prev
=
const_instruction
[
i
]
->
inputs
().
front
();
literals
[
i
]
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
});
for
(
size_t
i
=
0
;
i
<
const_instruction
.
size
();
i
++
)
{
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
m
.
replace_instruction
(
const_instruction
[
i
],
l
);
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/eliminate_data_type.cpp
View file @
11e155c2
...
@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
void
eliminate_data_type
::
apply
(
module
&
m
)
const
void
eliminate_data_type
::
apply
(
module
&
m
)
const
{
{
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
"convert"
,
"convert"
,
"get_tuple_elem"
,
"if"
,
"loop"
,
"roialign"
};
"get_tuple_elem"
,
"if"
,
"loop"
,
"roialign"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_none"
};
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()[
0
]
==
'@'
)
if
(
ins
->
name
()[
0
]
==
'@'
)
...
...
src/eliminate_identity.cpp
View file @
11e155c2
...
@@ -8,21 +8,21 @@
...
@@ -8,21 +8,21 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_identity
::
apply
(
module
&
p
)
const
void
eliminate_identity
::
apply
(
module
&
m
)
const
{
{
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// Skip the first instruction, since we always process the previous
// Skip the first instruction, since we always process the previous
// instruction
// instruction
if
(
ins
==
p
.
begin
())
if
(
ins
==
m
.
begin
())
continue
;
continue
;
const
auto
i
=
std
::
prev
(
ins
);
const
auto
i
=
std
::
prev
(
ins
);
if
(
i
->
name
()
==
"identity"
)
if
(
i
->
name
()
==
"identity"
)
{
{
p
.
replace_instruction
(
i
,
i
->
inputs
().
front
());
m
.
replace_instruction
(
i
,
i
->
inputs
().
front
());
p
.
move_instruction
(
i
,
p
.
end
());
m
.
move_instruction
(
i
,
m
.
end
());
}
}
if
(
ins
==
last
)
if
(
ins
==
last
)
{
{
...
@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
...
@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
const
instruction_ref
&
identity_input
=
ins
->
inputs
().
front
();
const
instruction_ref
&
identity_input
=
ins
->
inputs
().
front
();
if
(
identity_input
->
outputs
().
size
()
==
1
)
if
(
identity_input
->
outputs
().
size
()
==
1
)
{
{
p
.
move_instruction
(
identity_input
,
i
);
m
.
move_instruction
(
identity_input
,
i
);
// since this is the last instruction, removing it only
// since this is the last instruction, removing it only
// requires changing "last" and calling remove below
// requires changing "last" and calling remove below
last
=
std
::
prev
(
last
);
last
=
std
::
prev
(
last
);
...
@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
...
@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
break
;
break
;
}
}
}
}
p
.
remove_instructions
(
std
::
next
(
last
),
p
.
end
());
m
.
remove_instructions
(
std
::
next
(
last
),
m
.
end
());
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/eliminate_pad.cpp
View file @
11e155c2
...
@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
...
@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static
void
update_pooling
(
const
instruction_ref
&
input
,
const
instruction_ref
&
ins
,
module
&
m
)
static
void
update_pooling
(
const
instruction_ref
&
input
,
const
instruction_ref
&
ins
,
module
&
m
)
{
{
auto
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
auto
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
if
(
op
.
mode
==
"
average
"
)
if
(
op
.
mode
==
op
::
pooling_mode
::
average
)
{
{
return
;
return
;
}
}
...
...
src/include/migraphx/adjust_allocation.hpp
View file @
11e155c2
...
@@ -13,7 +13,7 @@ struct adjust_allocation
...
@@ -13,7 +13,7 @@ struct adjust_allocation
{
{
allocation_model
model
;
allocation_model
model
;
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/allocation_model.hpp
View file @
11e155c2
...
@@ -32,18 +32,22 @@ struct allocation_model
...
@@ -32,18 +32,22 @@ struct allocation_model
#else
#else
/*
#ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
*
// Type-erased interface for:
* struct allocation_model
struct
allocation_model
* {
{
* std::string name() const;
//
* std::string copy() const;
std
::
string
name
()
const
;
* operation allocate(const shape& s) const;
//
* operation preallocate(const shape& s,std::string id) const;
std
::
string
copy
()
const
;
* };
//
*
operation
allocate
(
const
shape
&
s
)
const
;
*/
//
operation
preallocate
(
const
shape
&
s
,
std
::
string
id
)
const
;
};
#else
struct
allocation_model
struct
allocation_model
{
{
...
@@ -260,6 +264,7 @@ inline const ValueType& any_cast(const allocation_model& x)
...
@@ -260,6 +264,7 @@ inline const ValueType& any_cast(const allocation_model& x)
throw
std
::
bad_cast
();
throw
std
::
bad_cast
();
return
*
y
;
return
*
y
;
}
}
#endif
#endif
#endif
...
...
src/include/migraphx/analyze_streams.hpp
View file @
11e155c2
...
@@ -16,7 +16,7 @@ struct stream_race
...
@@ -16,7 +16,7 @@ struct stream_race
instruction_ref
before
;
instruction_ref
before
;
};
};
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
p
,
const
stream_model
&
m
);
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
m
,
const
stream_model
&
strm
m
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/any_ptr.hpp
0 → 100644
View file @
11e155c2
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#include <migraphx/config.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/type_name.hpp>
#include <cassert>
#include <string_view>
#include <typeindex>
#include <type_traits>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
any_ptr
{
any_ptr
()
=
default
;
template
<
class
T
>
any_ptr
(
T
*
p
)
:
ptr
(
p
),
ti
(
typeid
(
T
*
)),
name
(
get_name
<
T
*>
())
{
}
any_ptr
(
void
*
p
,
std
::
string_view
pname
)
:
ptr
(
p
),
name
(
pname
)
{}
void
*
get
(
std
::
string_view
n
)
const
{
if
(
name
!=
n
)
MIGRAPHX_THROW
(
"any_ptr: type mismatch: "
+
std
::
string
{
name
}
+
" != "
+
std
::
string
{
n
});
return
ptr
;
}
template
<
class
T
>
T
get
()
const
{
static_assert
(
std
::
is_pointer
<
T
>
{},
"Must be a pointer"
);
assert
(
ptr
!=
nullptr
);
if
(
ti
and
std
::
type_index
{
typeid
(
T
)}
!=
*
ti
)
MIGRAPHX_THROW
(
"any_ptr: type mismatch: "
+
std
::
string
{
name
}
+
" != "
+
get_name
<
T
>
());
else
if
(
name
!=
get_name
<
T
>
())
MIGRAPHX_THROW
(
"any_ptr: type mismatch: "
+
std
::
string
{
name
}
+
" != "
+
get_name
<
T
>
());
return
reinterpret_cast
<
T
>
(
ptr
);
}
void
*
unsafe_get
()
const
{
return
ptr
;
}
private:
void
*
ptr
=
nullptr
;
optional
<
std
::
type_index
>
ti
=
nullopt
;
std
::
string_view
name
=
""
;
template
<
class
T
>
static
const
std
::
string
&
get_name
()
{
return
get_type_name
<
std
::
remove_cv_t
<
std
::
remove_pointer_t
<
T
>>>
();
}
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
src/include/migraphx/auto_contiguous.hpp
View file @
11e155c2
...
@@ -13,7 +13,7 @@ struct module;
...
@@ -13,7 +13,7 @@ struct module;
struct
auto_contiguous
struct
auto_contiguous
{
{
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
Prev
1
2
3
4
5
6
7
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment