Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ff3bd8e6
Commit
ff3bd8e6
authored
May 12, 2021
by
Khalique Ahmed
Browse files
manual merge
parents
32b69ceb
c310bc5c
Changes
83
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
253 additions
and
24 deletions
+253
-24
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+5
-0
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+2
-0
src/include/migraphx/op/as_shape.hpp
src/include/migraphx/op/as_shape.hpp
+1
-1
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+1
-1
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+1
-1
src/include/migraphx/op/identity.hpp
src/include/migraphx/op/identity.hpp
+2
-5
src/include/migraphx/op/load.hpp
src/include/migraphx/op/load.hpp
+1
-1
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+1
-1
src/include/migraphx/op/prefix_scan_op.hpp
src/include/migraphx/op/prefix_scan_op.hpp
+96
-0
src/include/migraphx/op/prefix_scan_sum.hpp
src/include/migraphx/op/prefix_scan_sum.hpp
+32
-0
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+1
-1
src/include/migraphx/op/scalar.hpp
src/include/migraphx/op/scalar.hpp
+1
-1
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+1
-1
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+1
-1
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+1
-1
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+2
-1
src/include/migraphx/output_iterator.hpp
src/include/migraphx/output_iterator.hpp
+46
-0
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+53
-7
src/include/migraphx/pass_manager.hpp
src/include/migraphx/pass_manager.hpp
+2
-1
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+3
-0
No files found.
src/include/migraphx/functional.hpp
View file @
ff3bd8e6
...
@@ -226,6 +226,11 @@ struct id
...
@@ -226,6 +226,11 @@ struct id
}
}
};
};
template
<
class
...
Ts
>
void
nop
(
Ts
&&
...)
{
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/literal.hpp
View file @
ff3bd8e6
...
@@ -66,6 +66,8 @@ struct literal : raw_data<literal>
...
@@ -66,6 +66,8 @@ struct literal : raw_data<literal>
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
std
::
vector
<
literal
>
get_sub_objects
()
const
{
return
{};
}
/// Convert the data to an argument
/// Convert the data to an argument
argument
get_argument
()
const
argument
get_argument
()
const
{
{
...
...
src/include/migraphx/op/as_shape.hpp
View file @
ff3bd8e6
...
@@ -33,7 +33,7 @@ struct as_shape
...
@@ -33,7 +33,7 @@ struct as_shape
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)}
;
return
args
.
front
().
reshape
(
output_shape
)
;
}
}
bool
is_borrowed
()
const
{
return
true
;
}
bool
is_borrowed
()
const
{
return
true
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
...
...
src/include/migraphx/op/broadcast.hpp
View file @
ff3bd8e6
...
@@ -64,7 +64,7 @@ struct broadcast
...
@@ -64,7 +64,7 @@ struct broadcast
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
mov
e
(
output_shape
)
,
std
::
move
(
args
.
at
(
0
).
data
)}
;
return
args
[
0
].
reshap
e
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
bool
is_borrowed
()
const
{
return
true
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
...
...
src/include/migraphx/op/flatten.hpp
View file @
ff3bd8e6
...
@@ -48,7 +48,7 @@ struct flatten
...
@@ -48,7 +48,7 @@ struct flatten
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
mov
e
(
output_shape
)
,
std
::
move
(
args
.
front
().
data
)}
;
return
args
[
0
].
reshap
e
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
bool
is_borrowed
()
const
{
return
true
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
...
...
src/include/migraphx/op/identity.hpp
View file @
ff3bd8e6
...
@@ -19,11 +19,8 @@ struct identity
...
@@ -19,11 +19,8 @@ struct identity
{
{
std
::
string
name
()
const
{
return
"identity"
;
}
std
::
string
name
()
const
{
return
"identity"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
0
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
0
);
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
,
std
::
vector
<
argument
>
args
)
const
{
return
args
[
0
];
}
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
bool
is_borrowed
()
const
{
return
true
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/load.hpp
View file @
ff3bd8e6
...
@@ -34,7 +34,7 @@ struct load
...
@@ -34,7 +34,7 @@ struct load
{
{
if
((
offset
+
s
.
bytes
())
>
args
[
0
].
get_shape
().
bytes
())
if
((
offset
+
s
.
bytes
())
>
args
[
0
].
get_shape
().
bytes
())
MIGRAPHX_THROW
(
"Load access is out of bounds"
);
MIGRAPHX_THROW
(
"Load access is out of bounds"
);
return
{
s
,
args
[
0
].
data
()
+
offset
}
;
return
argument
::
load
(
s
,
args
[
0
].
data
()
+
offset
)
;
}
}
bool
is_borrowed
()
const
{
return
true
;
}
bool
is_borrowed
()
const
{
return
true
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
ff3bd8e6
...
@@ -66,7 +66,7 @@ struct multibroadcast
...
@@ -66,7 +66,7 @@ struct multibroadcast
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
mov
e
(
output_shape
)
,
std
::
move
(
args
.
at
(
0
).
data
)}
;
return
args
[
0
].
reshap
e
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
bool
is_borrowed
()
const
{
return
true
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
...
...
src/include/migraphx/op/prefix_scan_op.hpp
0 → 100644
View file @
ff3bd8e6
#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
template
<
class
Derived
>
struct
prefix_scan_op
:
op_name
<
Derived
>
{
int64_t
axis
;
bool
exclusive
=
false
;
bool
reverse
=
false
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
),
f
(
self
.
exclusive
,
"exclusive"
),
f
(
self
.
reverse
,
"reverse"
));
}
value
attributes
()
const
{
value
normalize
;
normalize
[
"axis"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
return
{{
"normalize_axes"
,
normalize
}};
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
at
(
0
);
}
argument
compute
(
const
shape
&
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
=
args
[
0
];
auto
s
=
result
.
get_shape
();
auto
slice
=
shape
{
s
.
type
(),
{
s
.
lens
()[
axis
]},
{
s
.
strides
()[
axis
]}};
auto
lens
=
s
.
lens
();
lens
[
axis
]
=
1
;
auto
batch
=
shape
{
s
.
type
(),
lens
,
s
.
strides
()};
auto
&
self
=
static_cast
<
const
Derived
&>
(
*
this
);
result
.
visit
([
&
](
auto
output
)
{
using
type
=
decltype
(
output
);
par_for
(
batch
.
elements
(),
[
&
](
auto
i
)
{
auto
*
start
=
output
.
data
()
+
batch
.
index
(
i
);
type
x
{
slice
,
start
};
if
(
reverse
)
{
if
(
exclusive
)
{
std
::
copy
(
++
x
.
begin
(),
x
.
end
(),
x
.
begin
());
x
.
back
()
=
0
;
}
std
::
partial_sum
(
std
::
make_reverse_iterator
(
x
.
end
()),
std
::
make_reverse_iterator
(
x
.
begin
()),
std
::
make_reverse_iterator
(
x
.
end
()),
self
.
op
());
}
else
{
if
(
exclusive
)
{
std
::
copy_backward
(
x
.
begin
(),
--
x
.
end
(),
x
.
end
());
x
.
front
()
=
0
;
}
std
::
partial_sum
(
x
.
begin
(),
x
.
end
(),
x
.
begin
(),
self
.
op
());
}
});
});
return
result
;
}
auto
init
()
const
{}
prefix_scan_op
()
:
axis
(
0
)
{}
prefix_scan_op
(
int64_t
ax
)
:
axis
(
ax
)
{}
prefix_scan_op
(
int64_t
ax
,
bool
excl
)
:
axis
(
ax
),
exclusive
(
excl
)
{}
prefix_scan_op
(
int64_t
ax
,
bool
excl
,
bool
rev
)
:
axis
(
ax
),
exclusive
(
excl
),
reverse
(
rev
)
{}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/prefix_scan_sum.hpp
0 → 100644
View file @
ff3bd8e6
#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_INCLUSIVE_SUM_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCAN_INCLUSIVE_SUM_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/op/prefix_scan_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
prefix_scan_sum
:
prefix_scan_op
<
prefix_scan_sum
>
{
prefix_scan_sum
()
{}
prefix_scan_sum
(
int64_t
ax
)
:
prefix_scan_op
(
ax
)
{}
prefix_scan_sum
(
int64_t
ax
,
bool
excl
)
:
prefix_scan_op
(
ax
,
excl
)
{}
prefix_scan_sum
(
int64_t
ax
,
bool
excl
,
bool
rev
)
:
prefix_scan_op
(
ax
,
excl
,
rev
)
{}
auto
op
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
x
+
y
;
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/reshape.hpp
View file @
ff3bd8e6
...
@@ -68,7 +68,7 @@ struct reshape
...
@@ -68,7 +68,7 @@ struct reshape
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
mov
e
(
output_shape
)
,
std
::
move
(
args
.
front
().
data
)}
;
return
args
[
0
].
reshap
e
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
bool
is_borrowed
()
const
{
return
true
;
}
...
...
src/include/migraphx/op/scalar.hpp
View file @
ff3bd8e6
...
@@ -37,7 +37,7 @@ struct scalar
...
@@ -37,7 +37,7 @@ struct scalar
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
mov
e
(
output_shape
)
,
std
::
move
(
args
.
at
(
0
).
data
)}
;
return
args
[
0
].
reshap
e
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
bool
is_borrowed
()
const
{
return
true
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
...
...
src/include/migraphx/op/squeeze.hpp
View file @
ff3bd8e6
...
@@ -75,7 +75,7 @@ struct squeeze
...
@@ -75,7 +75,7 @@ struct squeeze
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
mov
e
(
output_shape
)
,
std
::
move
(
args
.
front
().
data
)}
;
return
args
[
0
].
reshap
e
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
bool
is_borrowed
()
const
{
return
true
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
...
...
src/include/migraphx/op/transpose.hpp
View file @
ff3bd8e6
...
@@ -61,7 +61,7 @@ struct transpose
...
@@ -61,7 +61,7 @@ struct transpose
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
mov
e
(
output_shape
)
,
std
::
move
(
args
.
front
().
data
)}
;
return
args
[
0
].
reshap
e
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
bool
is_borrowed
()
const
{
return
true
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
ff3bd8e6
...
@@ -68,7 +68,7 @@ struct unsqueeze
...
@@ -68,7 +68,7 @@ struct unsqueeze
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
mov
e
(
output_shape
)
,
std
::
move
(
args
.
front
().
data
)}
;
return
args
[
0
].
reshap
e
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
bool
is_borrowed
()
const
{
return
true
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
...
...
src/include/migraphx/operators.hpp
View file @
ff3bd8e6
...
@@ -58,10 +58,11 @@
...
@@ -58,10 +58,11 @@
#include <migraphx/op/outline.hpp>
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/prefix_scan_sum.hpp>
#include <migraphx/op/prelu.hpp>
#include <migraphx/op/prelu.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/recip.hpp>
#include <migraphx/op/recip.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_mean.hpp>
...
...
src/include/migraphx/output_iterator.hpp
0 → 100644
View file @
ff3bd8e6
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
#include <migraphx/config.hpp>
#include <iterator>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
F
>
struct
function_output_iterator
{
F
f
;
using
self
=
function_output_iterator
;
using
difference_type
=
void
;
using
reference
=
void
;
using
value_type
=
void
;
using
pointer
=
void
;
using
iterator_category
=
std
::
output_iterator_tag
;
struct
output_proxy
{
template
<
class
T
>
output_proxy
&
operator
=
(
const
T
&
value
)
{
assert
(
f
);
(
*
f
)(
value
);
return
*
this
;
}
F
*
f
;
};
output_proxy
operator
*
()
{
return
output_proxy
{
&
f
};
}
self
&
operator
++
()
{
return
*
this
;
}
self
&
operator
++
(
int
)
{
return
*
this
;
}
// NOLINT
};
template
<
class
F
>
function_output_iterator
<
F
>
make_function_output_iterator
(
F
f
)
{
return
{
std
::
move
(
f
)};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
src/include/migraphx/pass.hpp
View file @
ff3bd8e6
...
@@ -3,10 +3,10 @@
...
@@ -3,10 +3,10 @@
#include <cassert>
#include <cassert>
#include <string>
#include <string>
#include <functional>
#include <memory>
#include <memory>
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -23,8 +23,10 @@ struct pass
...
@@ -23,8 +23,10 @@ struct pass
{
{
/// A unique name used to identify the pass
/// A unique name used to identify the pass
std
::
string
name
()
const
;
std
::
string
name
()
const
;
/// Run the pass on the module
void
apply
(
module
&
m
)
const
;
/// Run the pass on the program
/// Run the pass on the program
void
apply
(
module
&
p
)
const
;
void
apply
(
program
&
p
)
const
;
};
};
#else
#else
...
@@ -35,7 +37,8 @@ struct pass
...
@@ -35,7 +37,8 @@ struct pass
* struct pass
* struct pass
* {
* {
* std::string name() const;
* std::string name() const;
* void apply(module & p) const;
* void apply(module & m) const;
* void apply(program & p) const;
* };
* };
*
*
*/
*/
...
@@ -109,7 +112,13 @@ struct pass
...
@@ -109,7 +112,13 @@ struct pass
return
(
*
this
).
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
void
apply
(
module
&
p
)
const
void
apply
(
module
&
m
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
apply
(
m
);
}
void
apply
(
program
&
p
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
...
@@ -128,10 +137,37 @@ struct pass
...
@@ -128,10 +137,37 @@ struct pass
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
void
apply
(
module
&
p
)
const
=
0
;
virtual
void
apply
(
module
&
m
)
const
=
0
;
virtual
void
apply
(
program
&
p
)
const
=
0
;
};
};
template
<
class
T
>
static
auto
private_detail_te_default_apply
(
char
,
T
&&
private_detail_te_self
,
module
&
m
)
->
decltype
(
private_detail_te_self
.
apply
(
m
))
{
private_detail_te_self
.
apply
(
m
);
}
template
<
class
T
>
static
void
private_detail_te_default_apply
(
float
,
T
&&
private_detail_te_self
,
module
&
m
)
{
migraphx
::
nop
(
private_detail_te_self
,
m
);
}
template
<
class
T
>
static
auto
private_detail_te_default_apply
(
char
,
T
&&
private_detail_te_self
,
program
&
p
)
->
decltype
(
private_detail_te_self
.
apply
(
p
))
{
private_detail_te_self
.
apply
(
p
);
}
template
<
class
T
>
static
void
private_detail_te_default_apply
(
float
,
T
&&
private_detail_te_self
,
program
&
p
)
{
migraphx
::
nop
(
private_detail_te_self
,
p
);
}
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
{
...
@@ -162,7 +198,17 @@ struct pass
...
@@ -162,7 +198,17 @@ struct pass
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
void
apply
(
module
&
p
)
const
override
{
private_detail_te_value
.
apply
(
p
);
}
void
apply
(
module
&
m
)
const
override
{
private_detail_te_default_apply
(
char
(
0
),
private_detail_te_value
,
m
);
}
void
apply
(
program
&
p
)
const
override
{
private_detail_te_default_apply
(
char
(
0
),
private_detail_te_value
,
p
);
}
PrivateDetailTypeErasedT
private_detail_te_value
;
PrivateDetailTypeErasedT
private_detail_te_value
;
};
};
...
...
src/include/migraphx/pass_manager.hpp
View file @
ff3bd8e6
...
@@ -17,7 +17,8 @@
...
@@ -17,7 +17,8 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
run_passes
(
module
&
modl
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/program.hpp
View file @
ff3bd8e6
...
@@ -101,6 +101,9 @@ struct program
...
@@ -101,6 +101,9 @@ struct program
std
::
vector
<
const
module
*>
get_modules
()
const
;
std
::
vector
<
const
module
*>
get_modules
()
const
;
std
::
vector
<
module
*>
get_modules
();
std
::
vector
<
module
*>
get_modules
();
void
remove_module
(
const
std
::
string
&
name
);
void
remove_unused_modules
();
private:
private:
void
assign
(
const
program
&
p
);
void
assign
(
const
program
&
p
);
std
::
unique_ptr
<
program_impl
>
impl
;
std
::
unique_ptr
<
program_impl
>
impl
;
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment