Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2f268bc2
Commit
2f268bc2
authored
Jun 12, 2022
by
Paul
Browse files
Merge branch 'develop' into mlir-c
parents
f75c5a38
aa7ff911
Changes
205
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
222 additions
and
199 deletions
+222
-199
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+150
-128
src/api/migraphx.py
src/api/migraphx.py
+4
-0
src/argument.cpp
src/argument.cpp
+0
-2
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+8
-8
src/compile_src.cpp
src/compile_src.cpp
+1
-1
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+6
-22
src/driver/marker_roctx.cpp
src/driver/marker_roctx.cpp
+1
-1
src/eliminate_allocation.cpp
src/eliminate_allocation.cpp
+4
-4
src/eliminate_common_subexpression.cpp
src/eliminate_common_subexpression.cpp
+5
-5
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+6
-6
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+21
-7
src/eliminate_identity.cpp
src/eliminate_identity.cpp
+8
-8
src/include/migraphx/adjust_allocation.hpp
src/include/migraphx/adjust_allocation.hpp
+1
-1
src/include/migraphx/analyze_streams.hpp
src/include/migraphx/analyze_streams.hpp
+1
-1
src/include/migraphx/auto_contiguous.hpp
src/include/migraphx/auto_contiguous.hpp
+1
-1
src/include/migraphx/check_context.hpp
src/include/migraphx/check_context.hpp
+1
-1
src/include/migraphx/compile_src.hpp
src/include/migraphx/compile_src.hpp
+1
-0
src/include/migraphx/eliminate_allocation.hpp
src/include/migraphx/eliminate_allocation.hpp
+1
-1
src/include/migraphx/eliminate_common_subexpression.hpp
src/include/migraphx/eliminate_common_subexpression.hpp
+1
-1
src/include/migraphx/eliminate_concat.hpp
src/include/migraphx/eliminate_concat.hpp
+1
-1
No files found.
src/api/include/migraphx/migraphx.hpp
View file @
2f268bc2
...
@@ -15,6 +15,16 @@ namespace migraphx {
...
@@ -15,6 +15,16 @@ namespace migraphx {
inline
namespace
api
{
// NOLINT
inline
namespace
api
{
// NOLINT
#endif
#endif
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template
<
int
N
>
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
struct
rank
:
rank
<
N
-
1
>
{
{
...
@@ -29,10 +39,7 @@ template <class T, class F, class... Ts>
...
@@ -29,10 +39,7 @@ template <class T, class F, class... Ts>
T
*
make
(
F
f
,
Ts
&&
...
xs
)
T
*
make
(
F
f
,
Ts
&&
...
xs
)
{
{
T
*
result
=
nullptr
;
T
*
result
=
nullptr
;
// cppcheck-suppress redundantInitialization
auto
e
=
f
(
&
result
,
std
::
forward
<
Ts
>
(
xs
)...);
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto
e
=
f
(
&
result
,
std
::
forward
<
Ts
>
(
xs
)...);
if
(
e
!=
migraphx_status_success
)
if
(
e
!=
migraphx_status_success
)
throw
std
::
runtime_error
(
"Failed to call function"
);
throw
std
::
runtime_error
(
"Failed to call function"
);
return
result
;
return
result
;
...
@@ -41,9 +48,6 @@ T* make(F f, Ts&&... xs)
...
@@ -41,9 +48,6 @@ T* make(F f, Ts&&... xs)
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
void
call
(
F
f
,
Ts
&&
...
xs
)
void
call
(
F
f
,
Ts
&&
...
xs
)
{
{
// cppcheck-suppress redundantInitialization
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto
e
=
f
(
std
::
forward
<
Ts
>
(
xs
)...);
auto
e
=
f
(
std
::
forward
<
Ts
>
(
xs
)...);
if
(
e
!=
migraphx_status_success
)
if
(
e
!=
migraphx_status_success
)
throw
std
::
runtime_error
(
"Failed to call function"
);
throw
std
::
runtime_error
(
"Failed to call function"
);
...
@@ -99,34 +103,22 @@ struct iota_iterator
...
@@ -99,34 +103,22 @@ struct iota_iterator
return
it
;
return
it
;
}
}
// TODO: operator->
// TODO: operator->
reference
operator
*
()
const
{
return
(
*
f
)(
index
);
}
reference
operator
*
()
const
{
return
f
(
index
);
}
};
template
<
class
F
,
class
Iterator
>
friend
iota_iterator
operator
+
(
iota_iterator
x
,
iota_iterator
y
)
inline
iota_iterator
<
F
,
Iterator
>
operator
+
(
iota_iterator
<
F
,
Iterator
>
x
,
{
iota_iterator
<
F
,
Iterator
>
y
)
return
iota_iterator
(
x
.
index
+
y
.
index
,
x
.
f
);
{
}
return
iota_iterator
<
F
,
Iterator
>
(
x
.
index
+
y
.
index
,
x
.
f
);
}
template
<
class
F
,
class
Iterator
>
friend
iota_iterator
operator
-
(
iota_iterator
x
,
iota_iterator
y
)
inline
iota_iterator
<
F
,
Iterator
>
operator
-
(
iota_iterator
<
F
,
Iterator
>
x
,
{
iota_iterator
<
F
,
Iterator
>
y
)
return
iota_iterator
(
x
.
index
-
y
.
index
,
x
.
f
);
{
}
return
iota_iterator
<
F
,
Iterator
>
(
x
.
index
-
y
.
index
,
x
.
f
);
}
template
<
class
F
,
class
Iterator
>
friend
bool
operator
==
(
iota_iterator
x
,
iota_iterator
y
)
{
return
x
.
index
==
y
.
index
;
}
inline
bool
operator
==
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
==
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
friend
bool
operator
!=
(
iota_iterator
x
,
iota_iterator
y
)
{
return
x
.
index
!=
y
.
index
;
}
inline
bool
operator
!=
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
};
{
return
x
.
index
!=
y
.
index
;
}
template
<
class
Derived
>
template
<
class
Derived
>
struct
array_base
struct
array_base
...
@@ -136,8 +128,20 @@ struct array_base
...
@@ -136,8 +128,20 @@ struct array_base
template
<
class
T
>
template
<
class
T
>
using
value_type_t
=
decltype
(
std
::
declval
<
T
>
()[
0
]);
using
value_type_t
=
decltype
(
std
::
declval
<
T
>
()[
0
]);
struct
iterator_read
{
const
Derived
*
self
;
template
<
class
D
=
Derived
>
value_type_t
<
D
>
operator
()(
size_t
pidx
)
const
{
return
(
*
self
)[
pidx
];
}
};
template
<
class
T
>
template
<
class
T
>
using
iterator_t
=
iota_iterator
<
typename
T
::
iterator_read
>
;
using
iterator_t
=
iota_iterator
<
iterator_read
>
;
bool
empty
()
const
{
return
derived
().
size
()
==
0
;
}
template
<
class
D
=
Derived
>
template
<
class
D
=
Derived
>
value_type_t
<
D
>
front
()
const
value_type_t
<
D
>
front
()
const
...
@@ -154,13 +158,13 @@ struct array_base
...
@@ -154,13 +158,13 @@ struct array_base
template
<
class
D
=
Derived
>
template
<
class
D
=
Derived
>
iterator_t
<
D
>
begin
()
const
iterator_t
<
D
>
begin
()
const
{
{
return
{
0
,
{
derived
()
.
get_handle_ptr
()
}};
return
{
0
,
{
&
derived
()}};
}
}
template
<
class
D
=
Derived
>
template
<
class
D
=
Derived
>
iterator_t
<
D
>
end
()
const
iterator_t
<
D
>
end
()
const
{
{
return
{
derived
().
size
(),
{
derived
()
.
get_handle_ptr
()
}};
return
{
derived
().
size
(),
{
&
derived
()}};
}
}
};
};
...
@@ -200,9 +204,25 @@ struct borrow
...
@@ -200,9 +204,25 @@ struct borrow
{
{
};
};
template
<
class
T
>
struct
share
{
share
(
std
::
shared_ptr
<
T
>
p
)
:
ptr
(
std
::
move
(
p
))
{}
template
<
class
U
>
std
::
shared_ptr
<
U
>
alias
(
U
*
p
)
const
{
return
std
::
shared_ptr
<
U
>
{
ptr
,
p
};
}
private:
std
::
shared_ptr
<
T
>
ptr
;
};
template
<
class
Derived
,
class
T
,
class
D
,
D
Deleter
,
class
A
,
A
Assigner
>
template
<
class
Derived
,
class
T
,
class
D
,
D
Deleter
,
class
A
,
A
Assigner
>
struct
handle_base
:
handle_lookup
<
Derived
,
std
::
remove_cv_t
<
T
>>
struct
handle_base
:
handle_lookup
<
Derived
,
std
::
remove_cv_t
<
T
>>
{
{
using
handle_type
=
T
;
handle_base
()
:
m_handle
(
nullptr
)
{}
handle_base
()
:
m_handle
(
nullptr
)
{}
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
void
make_handle
(
F
f
,
Ts
&&
...
xs
)
void
make_handle
(
F
f
,
Ts
&&
...
xs
)
...
@@ -231,6 +251,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
...
@@ -231,6 +251,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
m_handle
=
std
::
shared_ptr
<
U
>
{
ptr
,
[](
U
*
)
{}};
m_handle
=
std
::
shared_ptr
<
U
>
{
ptr
,
[](
U
*
)
{}};
}
}
template
<
class
U
,
class
V
>
void
set_handle
(
U
*
ptr
,
share
<
V
>
b
)
{
m_handle
=
std
::
shared_ptr
<
T
>
{
ptr
,
[
b
](
U
*
)
{}};
}
share
<
T
>
share_handle
()
const
{
return
{
m_handle
};
}
template
<
class
U
>
template
<
class
U
>
void
assign_to_handle
(
U
*
x
)
void
assign_to_handle
(
U
*
x
)
{
{
...
@@ -241,6 +269,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
...
@@ -241,6 +269,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std
::
shared_ptr
<
T
>
m_handle
;
std
::
shared_ptr
<
T
>
m_handle
;
};
};
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template
<
class
Base
>
template
<
class
Base
>
struct
interface_base
:
Base
struct
interface_base
:
Base
{
{
...
@@ -269,6 +308,7 @@ struct interface_base : Base
...
@@ -269,6 +308,7 @@ struct interface_base : Base
T
**
y
=
reinterpret_cast
<
T
**>
(
out
);
T
**
y
=
reinterpret_cast
<
T
**>
(
out
);
T
*
x
=
reinterpret_cast
<
T
*>
(
input
);
T
*
x
=
reinterpret_cast
<
T
*>
(
input
);
assert
(
x
!=
nullptr
and
y
!=
nullptr
and
*
y
==
nullptr
);
assert
(
x
!=
nullptr
and
y
!=
nullptr
and
*
y
==
nullptr
);
// cppcheck-suppress useSmartPointer
*
y
=
new
T
(
*
x
);
// NOLINT
*
y
=
new
T
(
*
x
);
// NOLINT
});
});
};
};
...
@@ -398,11 +438,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -398,11 +438,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{
{
shape
()
{}
shape
()
{}
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
shape
(
migraphx_shape
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shape
);
shape
(
migraphx_shape
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Construct a scalar shape
/// Construct a scalar shape
shape
(
migraphx_shape_datatype_t
type
)
shape
(
migraphx_shape_datatype_t
type
)
...
@@ -479,10 +518,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -479,10 +518,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
{
argument
()
{}
argument
()
{}
argument
(
migraphx_argument
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
argument
);
argument
(
migraphx_argument
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
shape
pshape
,
void
*
pbuffer
)
argument
(
shape
pshape
,
void
*
pbuffer
)
...
@@ -494,7 +532,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -494,7 +532,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
{
const_migraphx_shape_t
pout
;
const_migraphx_shape_t
pout
;
call
(
&
migraphx_argument_shape
,
&
pout
,
this
->
get_handle_ptr
());
call
(
&
migraphx_argument_shape
,
&
pout
,
this
->
get_handle_ptr
());
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
char
*
data
()
const
char
*
data
()
const
...
@@ -526,9 +564,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
...
@@ -526,9 +564,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{
{
target
()
{}
target
()
{}
target
(
migraphx_target
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
target
);
target
(
migraphx_target
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Construct a target from its name
/// Construct a target from its name
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
...
@@ -538,15 +574,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -538,15 +574,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
{
program_parameter_shapes
()
{}
program_parameter_shapes
()
{}
program_parameter_shapes
(
migraphx_program_parameter_shapes
*
p
,
own
)
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameter_shapes
);
{
this
->
set_handle
(
p
,
own
{});
}
program_parameter_shapes
(
migraphx_program_parameter_shapes
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -559,7 +587,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -559,7 +587,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
{
const_migraphx_shape_t
pout
;
const_migraphx_shape_t
pout
;
call
(
&
migraphx_program_parameter_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pname
);
call
(
&
migraphx_program_parameter_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pname
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
std
::
vector
<
const
char
*>
names
()
const
std
::
vector
<
const
char
*>
names
()
const
...
@@ -576,10 +604,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -576,10 +604,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program
/// A class to construct the inputs parameters for a program
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
{
{
program_parameters
(
migraphx_program_parameters
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameters
);
program_parameters
(
migraphx_program_parameters
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
program_parameters
()
{
this
->
make_handle
(
&
migraphx_program_parameters_create
);
}
program_parameters
()
{
this
->
make_handle
(
&
migraphx_program_parameters_create
);
}
...
@@ -604,9 +631,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
...
@@ -604,9 +631,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
{
{
arguments
(
migraphx_arguments
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
arguments
);
arguments
(
migraphx_arguments
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -619,27 +644,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
...
@@ -619,27 +644,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
{
const_migraphx_argument_t
pout
;
const_migraphx_argument_t
pout
;
call
(
&
migraphx_arguments_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
call
(
&
migraphx_arguments_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
struct
iterator_read
{
migraphx_arguments
*
self
;
argument
operator
()(
size_t
pidx
)
const
{
const_migraphx_argument_t
pout
;
call
(
&
migraphx_arguments_get
,
&
pout
,
self
,
pidx
);
return
{
pout
};
}
};
};
};
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
{
{
shapes
(
migraphx_shapes
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shapes
);
shapes
(
migraphx_shapes
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -652,26 +663,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
...
@@ -652,26 +663,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
{
const_migraphx_shape_t
pout
;
const_migraphx_shape_t
pout
;
call
(
&
migraphx_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
call
(
&
migraphx_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
struct
iterator_read
{
migraphx_shapes
*
self
;
shape
operator
()(
size_t
pidx
)
const
{
const_migraphx_shape_t
pout
;
call
(
&
migraphx_shapes_get
,
&
pout
,
self
,
pidx
);
return
{
pout
};
}
};
};
};
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
{
operation
(
migraphx_operation
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
operation
);
operation
(
migraphx_operation
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
...
@@ -689,15 +687,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
...
@@ -689,15 +687,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
{
{
instruction
(
migraphx_instruction
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instruction
);
};
};
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instructions
);
instructions
(
migraphx_instructions
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
instructions
(
migraphx_instructions
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
instructions
(
Ts
...
xs
)
instructions
(
Ts
...
xs
)
...
@@ -711,33 +706,36 @@ struct module;
...
@@ -711,33 +706,36 @@ struct module;
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
modules
);
modules
(
migraphx_modules
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
modules
(
migraphx_modules
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
modules
(
Ts
...
xs
)
modules
(
Ts
...
xs
)
{
{
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
mm
...};
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
get_handle_ptr
()
...};
this
->
make_handle
(
&
migraphx_modules_create
,
a
.
data
(),
a
.
size
());
this
->
make_handle
(
&
migraphx_modules_create
,
a
.
data
(),
a
.
size
());
}
}
};
};
struct
module
struct
module
{
{
migraphx_module_t
mm
;
MIGRAPHX_DEPRECATED
(
"Constructor without lifetime annotation is deprecated."
)
module
(
migraphx_module
*
m
)
:
mm
(
std
::
shared_ptr
<
migraphx_module
*>
(),
m
)
{}
module
(
const
migraphx_module
_t
&
m
)
:
mm
(
m
)
{}
module
(
migraphx_module
*
m
,
borrow
)
:
mm
(
std
::
shared_ptr
<
migraphx_module
*>
(),
m
)
{}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
);
}
template
<
class
T
>
module
(
migraphx_module
*
m
,
share
<
T
>
b
)
:
mm
(
b
.
alias
(
m
))
{
}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
.
get
());
}
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
)
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
)
{
{
migraphx_instruction_t
op_ins
;
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction
,
call
(
&
migraphx_module_add_instruction
,
&
op_ins
,
&
op_ins
,
mm
,
mm
.
get
()
,
op
.
get_handle_ptr
(),
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
());
args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
return
instruction
(
op_ins
,
own
{});
...
@@ -750,40 +748,72 @@ struct module
...
@@ -750,40 +748,72 @@ struct module
migraphx_instruction_t
op_ins
;
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction_with_mod_args
,
call
(
&
migraphx_module_add_instruction_with_mod_args
,
&
op_ins
,
&
op_ins
,
mm
,
mm
.
get
()
,
op
.
get_handle_ptr
(),
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
(),
args
.
get_handle_ptr
(),
module_args
.
get_handle_ptr
());
module_args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
return
instruction
(
op_ins
,
own
{});
}
}
template
<
typename
T
>
instruction
add_literal
(
const
migraphx
::
shape
&
s
,
T
*
buffer
)
{
migraphx_instruction_t
literal_ins
;
const
auto
*
buffer_ptr
=
reinterpret_cast
<
const
char
*>
(
buffer
);
call
(
&
migraphx_module_add_literal
,
&
literal_ins
,
mm
.
get
(),
s
.
get_handle_ptr
(),
buffer_ptr
);
return
instruction
(
literal_ins
,
own
{});
}
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
{
{
migraphx_instruction_t
param_ins
;
migraphx_instruction_t
param_ins
;
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
,
name
.
c_str
(),
s
.
get_handle_ptr
());
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
.
get
(),
name
.
c_str
(),
s
.
get_handle_ptr
());
return
instruction
(
param_ins
,
own
{});
return
instruction
(
param_ins
,
own
{});
}
}
instruction
add_return
(
const
migraphx
::
instructions
&
args
)
instruction
add_return
(
const
migraphx
::
instructions
&
args
)
{
{
migraphx_instruction_t
ret_ins
;
migraphx_instruction_t
ret_ins
;
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
,
args
.
get_handle_ptr
());
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
.
get
()
,
args
.
get_handle_ptr
());
return
instruction
(
ret_ins
,
own
{});
return
instruction
(
ret_ins
,
own
{});
}
}
migraphx_module_t
get_handle_ptr
()
const
{
return
mm
.
get
();
}
private:
std
::
shared_ptr
<
migraphx_module
>
mm
;
};
};
struct
context
struct
context
{
{
migraphx_context
_t
ctx
;
context
(
migraphx_context
*
p
,
borrow
)
:
ctx
(
std
::
shared_ptr
<
migraphx_context
*>
(),
p
)
{}
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
);
}
template
<
class
T
>
context
(
migraphx_context
*
p
,
share
<
T
>
b
)
:
ctx
(
b
.
alias
(
p
))
{
}
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
.
get
());
}
template
<
class
T
>
T
get_queue
()
{
void
*
out
;
call
(
&
migraphx_context_get_queue
,
&
out
,
ctx
.
get
());
// TODO: check type here
return
reinterpret_cast
<
T
>
(
out
);
}
private:
std
::
shared_ptr
<
migraphx_context
>
ctx
;
};
};
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
{
{
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
compile_options
(
migraphx_compile_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
());
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
compile_options
);
/// For targets with offloaded memory(such as the gpu), this will insert
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// instructions during compilation to copy the input parameters to the
...
@@ -807,9 +837,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -807,9 +837,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
{
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
(
migraphx_program
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program
);
program
(
migraphx_program
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Compile the program for a specific target to be ran on
/// Compile the program for a specific target to be ran on
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
...
@@ -872,21 +900,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -872,21 +900,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
{
migraphx_module_t
p_modu
;
migraphx_module_t
p_modu
;
call
(
&
migraphx_program_get_main_module
,
&
p_modu
,
this
->
get_handle_ptr
());
call
(
&
migraphx_program_get_main_module
,
&
p_modu
,
this
->
get_handle_ptr
());
return
module
{
p_modu
};
return
module
{
p_modu
,
this
->
share_handle
()
};
}
}
context
experimental_get_context
()
context
experimental_get_context
()
{
{
migraphx_context_t
ctx
;
migraphx_context_t
ctx
;
call
(
&
migraphx_program_experimental_get_context
,
&
ctx
,
this
->
get_handle_ptr
());
call
(
&
migraphx_program_experimental_get_context
,
&
ctx
,
this
->
get_handle_ptr
());
return
context
{
ctx
};
return
context
{
ctx
,
this
->
share_handle
()
};
}
}
module
create_module
(
const
std
::
string
&
name
)
module
create_module
(
const
std
::
string
&
name
)
{
{
migraphx_module_t
p_modu
;
migraphx_module_t
p_modu
;
call
(
&
migraphx_program_create_module
,
&
p_modu
,
this
->
get_handle_ptr
(),
name
.
data
());
call
(
&
migraphx_program_create_module
,
&
p_modu
,
this
->
get_handle_ptr
(),
name
.
data
());
return
module
{
p_modu
};
return
module
{
p_modu
,
this
->
share_handle
()
};
}
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
...
@@ -895,10 +923,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -895,10 +923,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options
// options for migraphx file format options
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
file_options
);
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
file_options
(
migraphx_file_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
());
}
// set file format
// set file format
void
set_file_format
(
const
char
*
format
)
void
set_file_format
(
const
char
*
format
)
{
{
...
@@ -938,7 +965,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
...
@@ -938,7 +965,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
{
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
onnx_options
(
migraphx_onnx_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
onnx_options
);
/// Make onnx parser treat an inputs with a certain dimensions
/// Make onnx parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
@@ -1020,7 +1047,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
...
@@ -1020,7 +1047,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
{
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
tf_options
(
migraphx_tf_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
tf_options
);
/// Make tf parser treat an inputs with a certain dimensions
/// Make tf parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
@@ -1073,7 +1100,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
...
@@ -1073,7 +1100,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
{
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
quantize_op_names
(
migraphx_quantize_op_names
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_op_names
);
void
add
(
const
std
::
string
&
name
)
void
add
(
const
std
::
string
&
name
)
{
{
...
@@ -1098,12 +1125,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
...
@@ -1098,12 +1125,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
{
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
quantize_int8_options
(
migraphx_quantize_int8_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_int8_options
);
quantize_int8_options
(
migraphx_quantize_int8_options
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Add an operator that should be quantized
/// Add an operator that should be quantized
void
add_op_name
(
const
std
::
string
&
name
)
void
add_op_name
(
const
std
::
string
&
name
)
...
...
src/api/migraphx.py
View file @
2f268bc2
...
@@ -212,6 +212,9 @@ def module(h):
...
@@ -212,6 +212,9 @@ def module(h):
module_refs
=
'std::vector<migraphx::module*>'
),
module_refs
=
'std::vector<migraphx::module*>'
),
fname
=
'add_instruction'
,
fname
=
'add_instruction'
,
returns
=
'migraphx::instruction_ref'
)
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_literal'
,
api
.
params
(
shape
=
'const migraphx::shape&'
,
buffer
=
'const char*'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_parameter'
,
h
.
method
(
'add_parameter'
,
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
returns
=
'migraphx::instruction_ref'
)
returns
=
'migraphx::instruction_ref'
)
...
@@ -403,6 +406,7 @@ api.add_function('migraphx_quantize_int8',
...
@@ -403,6 +406,7 @@ api.add_function('migraphx_quantize_int8',
@
auto_handle
(
ref
=
True
)
@
auto_handle
(
ref
=
True
)
def
context
(
h
):
def
context
(
h
):
h
.
method
(
'finish'
,
const
=
True
)
h
.
method
(
'finish'
,
const
=
True
)
h
.
method
(
'get_queue'
,
returns
=
'void*'
,
fname
=
'get_queue().unsafe_get'
)
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
...
...
src/argument.cpp
View file @
2f268bc2
...
@@ -29,7 +29,6 @@ void argument::assign_buffer(std::function<char*()> d)
...
@@ -29,7 +29,6 @@ void argument::assign_buffer(std::function<char*()> d)
// Collect all shapes
// Collect all shapes
std
::
unordered_map
<
std
::
size_t
,
shape
>
shapes
;
std
::
unordered_map
<
std
::
size_t
,
shape
>
shapes
;
{
{
// cppcheck-suppress variableScope
std
::
size_t
i
=
0
;
std
::
size_t
i
=
0
;
fix
([
&
](
auto
self
,
auto
ss
)
{
fix
([
&
](
auto
self
,
auto
ss
)
{
if
(
ss
.
sub_shapes
().
empty
())
if
(
ss
.
sub_shapes
().
empty
())
...
@@ -60,7 +59,6 @@ void argument::assign_buffer(std::function<char*()> d)
...
@@ -60,7 +59,6 @@ void argument::assign_buffer(std::function<char*()> d)
}
}
assert
(
offset
==
s
.
bytes
());
assert
(
offset
==
s
.
bytes
());
// cppcheck-suppress variableScope
std
::
size_t
i
=
0
;
std
::
size_t
i
=
0
;
m_data
=
fix
<
data_t
>
([
&
](
auto
self
,
auto
ss
)
{
m_data
=
fix
<
data_t
>
([
&
](
auto
self
,
auto
ss
)
{
data_t
result
;
data_t
result
;
...
...
src/auto_contiguous.cpp
View file @
2f268bc2
...
@@ -8,10 +8,10 @@
...
@@ -8,10 +8,10 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
auto_contiguous
::
apply
(
module
&
p
)
const
void
auto_contiguous
::
apply
(
module
&
m
)
const
{
{
std
::
string
key
=
"require_std_shape"
;
std
::
string
key
=
"require_std_shape"
;
for
(
auto
ins
:
reverse_iterator_for
(
p
))
for
(
auto
ins
:
reverse_iterator_for
(
m
))
{
{
auto
&&
attr
=
ins
->
get_operator
().
attributes
();
auto
&&
attr
=
ins
->
get_operator
().
attributes
();
if
((
attr
.
get
(
key
,
false
)))
if
((
attr
.
get
(
key
,
false
)))
...
@@ -23,18 +23,18 @@ void auto_contiguous::apply(module& p) const
...
@@ -23,18 +23,18 @@ void auto_contiguous::apply(module& p) const
{
{
return
in
;
return
in
;
}
}
return
p
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
return
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
});
});
if
(
new_args
!=
args
)
if
(
new_args
!=
args
)
{
{
p
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
m
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
}
}
}
}
}
}
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// for last instruction that is NOT a return
// for last instruction that is NOT a return
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
...
@@ -42,8 +42,8 @@ void auto_contiguous::apply(module& p) const
...
@@ -42,8 +42,8 @@ void auto_contiguous::apply(module& p) const
shape
s
=
ins
->
get_shape
();
shape
s
=
ins
->
get_shape
();
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
{
{
auto
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
auto
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
p
.
replace_instruction
(
ins
,
c
);
m
.
replace_instruction
(
ins
,
c
);
}
}
}
}
}
}
...
...
src/compile_src.cpp
View file @
2f268bc2
...
@@ -28,7 +28,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
...
@@ -28,7 +28,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
{
{
params
+=
" "
+
src
.
path
.
filename
().
string
();
params
+=
" "
+
src
.
path
.
filename
().
string
();
if
(
out
.
empty
())
if
(
out
.
empty
())
out
=
src
.
path
.
stem
().
string
()
+
".o"
;
out
=
src
.
path
.
stem
().
string
()
+
out_ext
;
}
}
}
}
...
...
src/dead_code_elimination.cpp
View file @
2f268bc2
...
@@ -9,26 +9,6 @@
...
@@ -9,26 +9,6 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Range
,
class
Iterator
>
std
::
ptrdiff_t
bidistance
(
const
Range
&
r
,
Iterator
start
,
Iterator
last
)
{
auto
start_forward
=
start
;
auto
start_backwards
=
start
;
std
::
size_t
n
=
0
;
while
(
start_forward
!=
last
and
start_backwards
!=
last
)
{
n
++
;
if
(
start_forward
!=
r
.
end
())
start_forward
++
;
if
(
start_backwards
!=
r
.
begin
())
start_backwards
--
;
}
if
(
start_forward
==
last
)
return
n
;
else
return
-
n
;
}
void
dead_code_elimination
::
apply
(
program
&
p
)
const
{
p
.
remove_unused_modules
();
}
void
dead_code_elimination
::
apply
(
program
&
p
)
const
{
p
.
remove_unused_modules
();
}
void
dead_code_elimination
::
apply
(
module
&
m
)
const
void
dead_code_elimination
::
apply
(
module
&
m
)
const
...
@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
...
@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
if
(
i
->
get_shape
().
elements
()
==
0
and
i
->
name
().
front
()
!=
'@'
and
if
(
i
->
get_shape
().
elements
()
==
0
and
i
->
name
().
front
()
!=
'@'
and
i
->
name
()
!=
"undefined"
and
i
->
name
()
!=
"identity"
)
i
->
name
()
!=
"undefined"
and
i
->
name
()
!=
"identity"
)
continue
;
continue
;
assert
(
bidistance
(
m
,
i
,
last
)
>
0
);
assert
(
std
::
distance
(
m
.
begin
(),
i
)
<=
std
::
distance
(
m
.
begin
(),
last
));
std
::
unordered_set
<
instruction_ref
>
visited
;
fix
([
&
](
auto
self
,
auto
leaf
)
{
fix
([
&
](
auto
self
,
auto
leaf
)
{
if
(
not
m
.
has_instruction
(
leaf
))
if
(
not
m
.
has_instruction
(
leaf
))
return
;
return
;
if
(
leaf
->
outputs
().
empty
())
if
(
leaf
->
outputs
().
empty
())
{
{
// Dont visit inputs twice
if
(
not
visited
.
insert
(
leaf
).
second
)
return
;
std
::
unordered_set
<
instruction_ref
>
args
(
leaf
->
inputs
().
begin
(),
std
::
unordered_set
<
instruction_ref
>
args
(
leaf
->
inputs
().
begin
(),
leaf
->
inputs
().
end
());
leaf
->
inputs
().
end
());
leaf
->
clear_arguments
();
leaf
->
clear_arguments
();
assert
(
bi
distance
(
m
,
last
,
leaf
)
<
0
);
assert
(
std
::
distance
(
m
.
begin
(),
leaf
)
<
std
::
distance
(
m
.
begin
(),
last
)
);
assert
(
leaf
!=
ins
);
assert
(
leaf
!=
ins
);
if
(
leaf
->
name
()
!=
"@param"
)
if
(
leaf
->
name
()
!=
"@param"
)
m
.
move_instruction
(
leaf
,
m
.
end
());
m
.
move_instruction
(
leaf
,
m
.
end
());
...
...
src/driver/marker_roctx.cpp
View file @
2f268bc2
...
@@ -17,7 +17,7 @@ class marker_roctx
...
@@ -17,7 +17,7 @@ class marker_roctx
std
::
function
<
int
(
const
char
*
)
>
sym_roctx_range_push
;
std
::
function
<
int
(
const
char
*
)
>
sym_roctx_range_push
;
std
::
function
<
int
()
>
sym_roctx_range_pop
;
std
::
function
<
int
()
>
sym_roctx_range_pop
;
uint64_t
range_id
;
uint64_t
range_id
=
0
;
public:
public:
marker_roctx
()
marker_roctx
()
...
...
src/eliminate_allocation.cpp
View file @
2f268bc2
...
@@ -13,13 +13,13 @@
...
@@ -13,13 +13,13 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_allocation
::
apply
(
module
&
p
)
const
void
eliminate_allocation
::
apply
(
module
&
m
)
const
{
{
assert
(
alignment
>
0
);
assert
(
alignment
>
0
);
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
std
::
size_t
>>
allocs
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
std
::
size_t
>>
allocs
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()
!=
allocation_op
)
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
continue
;
...
@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
...
@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
}
}
if
(
n
>
0
)
if
(
n
>
0
)
{
{
auto
mem
=
p
.
add_parameter
(
"memory"
,
shape
{
shape
::
int8_type
,
{
n
}});
auto
mem
=
m
.
add_parameter
(
"memory"
,
shape
{
shape
::
int8_type
,
{
n
}});
for
(
auto
&&
pp
:
allocs
)
for
(
auto
&&
pp
:
allocs
)
{
{
auto
ins
=
pp
.
first
;
auto
ins
=
pp
.
first
;
auto
s
=
ins
->
get_shape
();
auto
s
=
ins
->
get_shape
();
auto
offset
=
pp
.
second
;
auto
offset
=
pp
.
second
;
p
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
}
}
}
}
...
...
src/eliminate_common_subexpression.cpp
View file @
2f268bc2
...
@@ -11,7 +11,7 @@ namespace migraphx {
...
@@ -11,7 +11,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Range
>
template
<
class
Range
>
void
cse_range
(
module
&
p
,
Range
&&
r
)
void
cse_range
(
module
&
m
,
Range
&&
r
)
{
{
std
::
unordered_multimap
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_multimap
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_set
<
instruction_ref
>
processed_ins
;
std
::
unordered_set
<
instruction_ref
>
processed_ins
;
...
@@ -30,24 +30,24 @@ void cse_range(module& p, Range&& r)
...
@@ -30,24 +30,24 @@ void cse_range(module& p, Range&& r)
continue
;
continue
;
if
(
*
eq
!=
*
ins
)
if
(
*
eq
!=
*
ins
)
continue
;
continue
;
p
.
replace_instruction
(
ins
,
eq
);
m
.
replace_instruction
(
ins
,
eq
);
processed_ins
.
emplace
(
ins
);
processed_ins
.
emplace
(
ins
);
std
::
vector
<
instruction_ref
>
outputs
;
std
::
vector
<
instruction_ref
>
outputs
;
std
::
copy_if
(
eq
->
outputs
().
begin
(),
std
::
copy_if
(
eq
->
outputs
().
begin
(),
eq
->
outputs
().
end
(),
eq
->
outputs
().
end
(),
std
::
back_inserter
(
outputs
),
std
::
back_inserter
(
outputs
),
[
&
](
auto
x
)
{
return
p
.
has_instruction
(
x
);
});
[
&
](
auto
x
)
{
return
m
.
has_instruction
(
x
);
});
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
});
});
cse_range
(
p
,
outputs
);
cse_range
(
m
,
outputs
);
}
}
instructions
.
emplace
(
ins
->
name
(),
ins
);
instructions
.
emplace
(
ins
->
name
(),
ins
);
}
}
}
}
void
eliminate_common_subexpression
::
apply
(
module
&
p
)
const
{
cse_range
(
p
,
iterator_for
(
p
));
}
void
eliminate_common_subexpression
::
apply
(
module
&
m
)
const
{
cse_range
(
m
,
iterator_for
(
m
));
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/eliminate_concat.cpp
View file @
2f268bc2
...
@@ -13,9 +13,9 @@
...
@@ -13,9 +13,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_concat
::
apply
(
module
&
p
)
const
void
eliminate_concat
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// Look for the concat operator
// Look for the concat operator
if
(
ins
->
name
()
!=
concat_opt
.
name
())
if
(
ins
->
name
()
!=
concat_opt
.
name
())
...
@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
...
@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
std
::
sort
(
sorted_allocations
.
begin
(),
std
::
sort
(
sorted_allocations
.
begin
(),
sorted_allocations
.
end
(),
sorted_allocations
.
end
(),
[
&
](
instruction_ref
x
,
instruction_ref
y
)
{
[
&
](
instruction_ref
x
,
instruction_ref
y
)
{
return
std
::
distance
(
p
.
begin
(),
x
)
<
std
::
distance
(
p
.
begin
(),
y
);
return
std
::
distance
(
m
.
begin
(),
x
)
<
std
::
distance
(
m
.
begin
(),
y
);
});
});
// Move "super" allocation to the front
// Move "super" allocation to the front
auto
first
=
sorted_allocations
.
front
();
auto
first
=
sorted_allocations
.
front
();
auto
super
=
p
.
move_instruction
(
last
,
first
);
auto
super
=
m
.
move_instruction
(
last
,
first
);
// Replace each allocation with a load
// Replace each allocation with a load
std
::
size_t
offset
=
0
;
std
::
size_t
offset
=
0
;
for
(
auto
alloc
:
allocations
)
for
(
auto
alloc
:
allocations
)
{
{
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
p
.
replace_instruction
(
alloc
,
op
,
{
super
});
m
.
replace_instruction
(
alloc
,
op
,
{
super
});
offset
+=
alloc
->
get_shape
().
bytes
();
offset
+=
alloc
->
get_shape
().
bytes
();
}
}
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
p
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"identity"
),
args
);
m
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"identity"
),
args
);
}
}
}
}
}
}
...
...
src/eliminate_contiguous.cpp
View file @
2f268bc2
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp>
#include <utility>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -69,9 +70,11 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -69,9 +70,11 @@ static bool try_compute_shape(instruction_ref ins,
return
try_compute_shape
(
ins
,
inputs
,
mods
);
return
try_compute_shape
(
ins
,
inputs
,
mods
);
}
}
void
eliminate_contiguous
::
apply
(
module
&
p
)
const
void
eliminate_contiguous
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
std
::
vector
<
instruction_ref
>
const_instruction
;
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// return instruction should have inputs with standard shape
// return instruction should have inputs with standard shape
if
(
ins
->
name
()
==
"@return"
)
if
(
ins
->
name
()
==
"@return"
)
...
@@ -81,6 +84,7 @@ void eliminate_contiguous::apply(module& p) const
...
@@ -81,6 +84,7 @@ void eliminate_contiguous::apply(module& p) const
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
auto
new_args
=
args
;
auto
mod_args
=
ins
->
module_inputs
();
auto
mod_args
=
ins
->
module_inputs
();
for
(
auto
arg
:
ins
->
inputs
())
for
(
auto
arg
:
ins
->
inputs
())
{
{
if
(
arg
->
name
()
==
op_name
)
if
(
arg
->
name
()
==
op_name
)
...
@@ -93,15 +97,25 @@ void eliminate_contiguous::apply(module& p) const
...
@@ -93,15 +97,25 @@ void eliminate_contiguous::apply(module& p) const
}
}
else
if
(
prev
->
can_eval
())
else
if
(
prev
->
can_eval
())
{
{
auto
c
=
op
::
contiguous
{};
const_instruction
.
push_back
(
arg
);
auto
r
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
p
.
replace_instruction
(
arg
,
l
);
}
}
}
}
}
}
}
}
// Perform evaluations in parallel
std
::
vector
<
argument
>
literals
(
const_instruction
.
size
());
par_for
(
const_instruction
.
size
(),
1
,
[
&
](
const
auto
i
)
{
auto
c
=
op
::
contiguous
{};
auto
prev
=
const_instruction
[
i
]
->
inputs
().
front
();
literals
[
i
]
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
});
for
(
size_t
i
=
0
;
i
<
const_instruction
.
size
();
i
++
)
{
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
m
.
replace_instruction
(
const_instruction
[
i
],
l
);
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/eliminate_identity.cpp
View file @
2f268bc2
...
@@ -8,21 +8,21 @@
...
@@ -8,21 +8,21 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_identity
::
apply
(
module
&
p
)
const
void
eliminate_identity
::
apply
(
module
&
m
)
const
{
{
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// Skip the first instruction, since we always process the previous
// Skip the first instruction, since we always process the previous
// instruction
// instruction
if
(
ins
==
p
.
begin
())
if
(
ins
==
m
.
begin
())
continue
;
continue
;
const
auto
i
=
std
::
prev
(
ins
);
const
auto
i
=
std
::
prev
(
ins
);
if
(
i
->
name
()
==
"identity"
)
if
(
i
->
name
()
==
"identity"
)
{
{
p
.
replace_instruction
(
i
,
i
->
inputs
().
front
());
m
.
replace_instruction
(
i
,
i
->
inputs
().
front
());
p
.
move_instruction
(
i
,
p
.
end
());
m
.
move_instruction
(
i
,
m
.
end
());
}
}
if
(
ins
==
last
)
if
(
ins
==
last
)
{
{
...
@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
...
@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
const
instruction_ref
&
identity_input
=
ins
->
inputs
().
front
();
const
instruction_ref
&
identity_input
=
ins
->
inputs
().
front
();
if
(
identity_input
->
outputs
().
size
()
==
1
)
if
(
identity_input
->
outputs
().
size
()
==
1
)
{
{
p
.
move_instruction
(
identity_input
,
i
);
m
.
move_instruction
(
identity_input
,
i
);
// since this is the last instruction, removing it only
// since this is the last instruction, removing it only
// requires changing "last" and calling remove below
// requires changing "last" and calling remove below
last
=
std
::
prev
(
last
);
last
=
std
::
prev
(
last
);
...
@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
...
@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
break
;
break
;
}
}
}
}
p
.
remove_instructions
(
std
::
next
(
last
),
p
.
end
());
m
.
remove_instructions
(
std
::
next
(
last
),
m
.
end
());
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/adjust_allocation.hpp
View file @
2f268bc2
...
@@ -13,7 +13,7 @@ struct adjust_allocation
...
@@ -13,7 +13,7 @@ struct adjust_allocation
{
{
allocation_model
model
;
allocation_model
model
;
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/analyze_streams.hpp
View file @
2f268bc2
...
@@ -16,7 +16,7 @@ struct stream_race
...
@@ -16,7 +16,7 @@ struct stream_race
instruction_ref
before
;
instruction_ref
before
;
};
};
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
p
,
const
stream_model
&
m
);
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
m
,
const
stream_model
&
strm
m
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/auto_contiguous.hpp
View file @
2f268bc2
...
@@ -13,7 +13,7 @@ struct module;
...
@@ -13,7 +13,7 @@ struct module;
struct
auto_contiguous
struct
auto_contiguous
{
{
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/check_context.hpp
View file @
2f268bc2
...
@@ -33,7 +33,7 @@ struct check_context
...
@@ -33,7 +33,7 @@ struct check_context
};
};
std
::
string
name
()
const
{
return
"check_context"
;
}
std
::
string
name
()
const
{
return
"check_context"
;
}
void
apply
(
module
&
p
)
const
{
p
.
insert_instruction
(
p
.
begin
(),
op
{});
}
void
apply
(
module
&
m
)
const
{
m
.
insert_instruction
(
m
.
begin
(),
op
{});
}
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/compile_src.hpp
View file @
2f268bc2
...
@@ -24,6 +24,7 @@ struct src_compiler
...
@@ -24,6 +24,7 @@ struct src_compiler
std
::
string
flags
=
""
;
std
::
string
flags
=
""
;
std
::
string
output
=
""
;
std
::
string
output
=
""
;
std
::
string
launcher
=
""
;
std
::
string
launcher
=
""
;
std
::
string
out_ext
=
".o"
;
std
::
function
<
fs
::
path
(
fs
::
path
)
>
process
=
nullptr
;
std
::
function
<
fs
::
path
(
fs
::
path
)
>
process
=
nullptr
;
std
::
vector
<
char
>
compile
(
const
std
::
vector
<
src_file
>&
srcs
)
const
;
std
::
vector
<
char
>
compile
(
const
std
::
vector
<
src_file
>&
srcs
)
const
;
};
};
...
...
src/include/migraphx/eliminate_allocation.hpp
View file @
2f268bc2
...
@@ -19,7 +19,7 @@ struct eliminate_allocation
...
@@ -19,7 +19,7 @@ struct eliminate_allocation
std
::
string
allocation_op
{};
std
::
string
allocation_op
{};
std
::
size_t
alignment
=
32
;
std
::
size_t
alignment
=
32
;
std
::
string
name
()
const
{
return
"eliminate_allocation"
;
}
std
::
string
name
()
const
{
return
"eliminate_allocation"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_common_subexpression.hpp
View file @
2f268bc2
...
@@ -16,7 +16,7 @@ struct module;
...
@@ -16,7 +16,7 @@ struct module;
struct
eliminate_common_subexpression
struct
eliminate_common_subexpression
{
{
std
::
string
name
()
const
{
return
"eliminate_common_subexpression"
;
}
std
::
string
name
()
const
{
return
"eliminate_common_subexpression"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_concat.hpp
View file @
2f268bc2
...
@@ -18,7 +18,7 @@ struct eliminate_concat
...
@@ -18,7 +18,7 @@ struct eliminate_concat
{
{
concat_optimization
concat_opt
;
concat_optimization
concat_opt
;
std
::
string
name
()
const
{
return
"eliminate_concat"
;
}
std
::
string
name
()
const
{
return
"eliminate_concat"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
Prev
1
2
3
4
5
6
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment