Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2f268bc2
Commit
2f268bc2
authored
Jun 12, 2022
by
Paul
Browse files
Merge branch 'develop' into mlir-c
parents
f75c5a38
aa7ff911
Changes
205
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
222 additions
and
199 deletions
+222
-199
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+150
-128
src/api/migraphx.py
src/api/migraphx.py
+4
-0
src/argument.cpp
src/argument.cpp
+0
-2
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+8
-8
src/compile_src.cpp
src/compile_src.cpp
+1
-1
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+6
-22
src/driver/marker_roctx.cpp
src/driver/marker_roctx.cpp
+1
-1
src/eliminate_allocation.cpp
src/eliminate_allocation.cpp
+4
-4
src/eliminate_common_subexpression.cpp
src/eliminate_common_subexpression.cpp
+5
-5
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+6
-6
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+21
-7
src/eliminate_identity.cpp
src/eliminate_identity.cpp
+8
-8
src/include/migraphx/adjust_allocation.hpp
src/include/migraphx/adjust_allocation.hpp
+1
-1
src/include/migraphx/analyze_streams.hpp
src/include/migraphx/analyze_streams.hpp
+1
-1
src/include/migraphx/auto_contiguous.hpp
src/include/migraphx/auto_contiguous.hpp
+1
-1
src/include/migraphx/check_context.hpp
src/include/migraphx/check_context.hpp
+1
-1
src/include/migraphx/compile_src.hpp
src/include/migraphx/compile_src.hpp
+1
-0
src/include/migraphx/eliminate_allocation.hpp
src/include/migraphx/eliminate_allocation.hpp
+1
-1
src/include/migraphx/eliminate_common_subexpression.hpp
src/include/migraphx/eliminate_common_subexpression.hpp
+1
-1
src/include/migraphx/eliminate_concat.hpp
src/include/migraphx/eliminate_concat.hpp
+1
-1
No files found.
src/api/include/migraphx/migraphx.hpp
View file @
2f268bc2
...
@@ -15,6 +15,16 @@ namespace migraphx {
...
@@ -15,6 +15,16 @@ namespace migraphx {
inline
namespace
api
{
// NOLINT
inline
namespace
api
{
// NOLINT
#endif
#endif
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template
<
int
N
>
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
struct
rank
:
rank
<
N
-
1
>
{
{
...
@@ -29,10 +39,7 @@ template <class T, class F, class... Ts>
...
@@ -29,10 +39,7 @@ template <class T, class F, class... Ts>
T
*
make
(
F
f
,
Ts
&&
...
xs
)
T
*
make
(
F
f
,
Ts
&&
...
xs
)
{
{
T
*
result
=
nullptr
;
T
*
result
=
nullptr
;
// cppcheck-suppress redundantInitialization
auto
e
=
f
(
&
result
,
std
::
forward
<
Ts
>
(
xs
)...);
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto
e
=
f
(
&
result
,
std
::
forward
<
Ts
>
(
xs
)...);
if
(
e
!=
migraphx_status_success
)
if
(
e
!=
migraphx_status_success
)
throw
std
::
runtime_error
(
"Failed to call function"
);
throw
std
::
runtime_error
(
"Failed to call function"
);
return
result
;
return
result
;
...
@@ -41,9 +48,6 @@ T* make(F f, Ts&&... xs)
...
@@ -41,9 +48,6 @@ T* make(F f, Ts&&... xs)
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
void
call
(
F
f
,
Ts
&&
...
xs
)
void
call
(
F
f
,
Ts
&&
...
xs
)
{
{
// cppcheck-suppress redundantInitialization
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto
e
=
f
(
std
::
forward
<
Ts
>
(
xs
)...);
auto
e
=
f
(
std
::
forward
<
Ts
>
(
xs
)...);
if
(
e
!=
migraphx_status_success
)
if
(
e
!=
migraphx_status_success
)
throw
std
::
runtime_error
(
"Failed to call function"
);
throw
std
::
runtime_error
(
"Failed to call function"
);
...
@@ -99,34 +103,22 @@ struct iota_iterator
...
@@ -99,34 +103,22 @@ struct iota_iterator
return
it
;
return
it
;
}
}
// TODO: operator->
// TODO: operator->
reference
operator
*
()
const
{
return
(
*
f
)(
index
);
}
reference
operator
*
()
const
{
return
f
(
index
);
}
};
template
<
class
F
,
class
Iterator
>
friend
iota_iterator
operator
+
(
iota_iterator
x
,
iota_iterator
y
)
inline
iota_iterator
<
F
,
Iterator
>
operator
+
(
iota_iterator
<
F
,
Iterator
>
x
,
{
iota_iterator
<
F
,
Iterator
>
y
)
return
iota_iterator
(
x
.
index
+
y
.
index
,
x
.
f
);
{
}
return
iota_iterator
<
F
,
Iterator
>
(
x
.
index
+
y
.
index
,
x
.
f
);
}
template
<
class
F
,
class
Iterator
>
friend
iota_iterator
operator
-
(
iota_iterator
x
,
iota_iterator
y
)
inline
iota_iterator
<
F
,
Iterator
>
operator
-
(
iota_iterator
<
F
,
Iterator
>
x
,
{
iota_iterator
<
F
,
Iterator
>
y
)
return
iota_iterator
(
x
.
index
-
y
.
index
,
x
.
f
);
{
}
return
iota_iterator
<
F
,
Iterator
>
(
x
.
index
-
y
.
index
,
x
.
f
);
}
template
<
class
F
,
class
Iterator
>
friend
bool
operator
==
(
iota_iterator
x
,
iota_iterator
y
)
{
return
x
.
index
==
y
.
index
;
}
inline
bool
operator
==
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
==
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
friend
bool
operator
!=
(
iota_iterator
x
,
iota_iterator
y
)
{
return
x
.
index
!=
y
.
index
;
}
inline
bool
operator
!=
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
};
{
return
x
.
index
!=
y
.
index
;
}
template
<
class
Derived
>
template
<
class
Derived
>
struct
array_base
struct
array_base
...
@@ -136,8 +128,20 @@ struct array_base
...
@@ -136,8 +128,20 @@ struct array_base
template
<
class
T
>
template
<
class
T
>
using
value_type_t
=
decltype
(
std
::
declval
<
T
>
()[
0
]);
using
value_type_t
=
decltype
(
std
::
declval
<
T
>
()[
0
]);
struct
iterator_read
{
const
Derived
*
self
;
template
<
class
D
=
Derived
>
value_type_t
<
D
>
operator
()(
size_t
pidx
)
const
{
return
(
*
self
)[
pidx
];
}
};
template
<
class
T
>
template
<
class
T
>
using
iterator_t
=
iota_iterator
<
typename
T
::
iterator_read
>
;
using
iterator_t
=
iota_iterator
<
iterator_read
>
;
bool
empty
()
const
{
return
derived
().
size
()
==
0
;
}
template
<
class
D
=
Derived
>
template
<
class
D
=
Derived
>
value_type_t
<
D
>
front
()
const
value_type_t
<
D
>
front
()
const
...
@@ -154,13 +158,13 @@ struct array_base
...
@@ -154,13 +158,13 @@ struct array_base
template
<
class
D
=
Derived
>
template
<
class
D
=
Derived
>
iterator_t
<
D
>
begin
()
const
iterator_t
<
D
>
begin
()
const
{
{
return
{
0
,
{
derived
()
.
get_handle_ptr
()
}};
return
{
0
,
{
&
derived
()}};
}
}
template
<
class
D
=
Derived
>
template
<
class
D
=
Derived
>
iterator_t
<
D
>
end
()
const
iterator_t
<
D
>
end
()
const
{
{
return
{
derived
().
size
(),
{
derived
()
.
get_handle_ptr
()
}};
return
{
derived
().
size
(),
{
&
derived
()}};
}
}
};
};
...
@@ -200,9 +204,25 @@ struct borrow
...
@@ -200,9 +204,25 @@ struct borrow
{
{
};
};
template
<
class
T
>
struct
share
{
share
(
std
::
shared_ptr
<
T
>
p
)
:
ptr
(
std
::
move
(
p
))
{}
template
<
class
U
>
std
::
shared_ptr
<
U
>
alias
(
U
*
p
)
const
{
return
std
::
shared_ptr
<
U
>
{
ptr
,
p
};
}
private:
std
::
shared_ptr
<
T
>
ptr
;
};
template
<
class
Derived
,
class
T
,
class
D
,
D
Deleter
,
class
A
,
A
Assigner
>
template
<
class
Derived
,
class
T
,
class
D
,
D
Deleter
,
class
A
,
A
Assigner
>
struct
handle_base
:
handle_lookup
<
Derived
,
std
::
remove_cv_t
<
T
>>
struct
handle_base
:
handle_lookup
<
Derived
,
std
::
remove_cv_t
<
T
>>
{
{
using
handle_type
=
T
;
handle_base
()
:
m_handle
(
nullptr
)
{}
handle_base
()
:
m_handle
(
nullptr
)
{}
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
void
make_handle
(
F
f
,
Ts
&&
...
xs
)
void
make_handle
(
F
f
,
Ts
&&
...
xs
)
...
@@ -231,6 +251,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
...
@@ -231,6 +251,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
m_handle
=
std
::
shared_ptr
<
U
>
{
ptr
,
[](
U
*
)
{}};
m_handle
=
std
::
shared_ptr
<
U
>
{
ptr
,
[](
U
*
)
{}};
}
}
template
<
class
U
,
class
V
>
void
set_handle
(
U
*
ptr
,
share
<
V
>
b
)
{
m_handle
=
std
::
shared_ptr
<
T
>
{
ptr
,
[
b
](
U
*
)
{}};
}
share
<
T
>
share_handle
()
const
{
return
{
m_handle
};
}
template
<
class
U
>
template
<
class
U
>
void
assign_to_handle
(
U
*
x
)
void
assign_to_handle
(
U
*
x
)
{
{
...
@@ -241,6 +269,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
...
@@ -241,6 +269,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std
::
shared_ptr
<
T
>
m_handle
;
std
::
shared_ptr
<
T
>
m_handle
;
};
};
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template
<
class
Base
>
template
<
class
Base
>
struct
interface_base
:
Base
struct
interface_base
:
Base
{
{
...
@@ -269,6 +308,7 @@ struct interface_base : Base
...
@@ -269,6 +308,7 @@ struct interface_base : Base
T
**
y
=
reinterpret_cast
<
T
**>
(
out
);
T
**
y
=
reinterpret_cast
<
T
**>
(
out
);
T
*
x
=
reinterpret_cast
<
T
*>
(
input
);
T
*
x
=
reinterpret_cast
<
T
*>
(
input
);
assert
(
x
!=
nullptr
and
y
!=
nullptr
and
*
y
==
nullptr
);
assert
(
x
!=
nullptr
and
y
!=
nullptr
and
*
y
==
nullptr
);
// cppcheck-suppress useSmartPointer
*
y
=
new
T
(
*
x
);
// NOLINT
*
y
=
new
T
(
*
x
);
// NOLINT
});
});
};
};
...
@@ -398,11 +438,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -398,11 +438,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{
{
shape
()
{}
shape
()
{}
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
shape
(
migraphx_shape
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shape
);
shape
(
migraphx_shape
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Construct a scalar shape
/// Construct a scalar shape
shape
(
migraphx_shape_datatype_t
type
)
shape
(
migraphx_shape_datatype_t
type
)
...
@@ -479,10 +518,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -479,10 +518,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
{
argument
()
{}
argument
()
{}
argument
(
migraphx_argument
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
argument
);
argument
(
migraphx_argument
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
shape
pshape
,
void
*
pbuffer
)
argument
(
shape
pshape
,
void
*
pbuffer
)
...
@@ -494,7 +532,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -494,7 +532,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
{
const_migraphx_shape_t
pout
;
const_migraphx_shape_t
pout
;
call
(
&
migraphx_argument_shape
,
&
pout
,
this
->
get_handle_ptr
());
call
(
&
migraphx_argument_shape
,
&
pout
,
this
->
get_handle_ptr
());
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
char
*
data
()
const
char
*
data
()
const
...
@@ -526,9 +564,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
...
@@ -526,9 +564,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{
{
target
()
{}
target
()
{}
target
(
migraphx_target
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
target
);
target
(
migraphx_target
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Construct a target from its name
/// Construct a target from its name
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
...
@@ -538,15 +574,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -538,15 +574,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
{
program_parameter_shapes
()
{}
program_parameter_shapes
()
{}
program_parameter_shapes
(
migraphx_program_parameter_shapes
*
p
,
own
)
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameter_shapes
);
{
this
->
set_handle
(
p
,
own
{});
}
program_parameter_shapes
(
migraphx_program_parameter_shapes
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -559,7 +587,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -559,7 +587,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
{
const_migraphx_shape_t
pout
;
const_migraphx_shape_t
pout
;
call
(
&
migraphx_program_parameter_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pname
);
call
(
&
migraphx_program_parameter_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pname
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
std
::
vector
<
const
char
*>
names
()
const
std
::
vector
<
const
char
*>
names
()
const
...
@@ -576,10 +604,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -576,10 +604,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program
/// A class to construct the inputs parameters for a program
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
{
{
program_parameters
(
migraphx_program_parameters
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameters
);
program_parameters
(
migraphx_program_parameters
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
program_parameters
()
{
this
->
make_handle
(
&
migraphx_program_parameters_create
);
}
program_parameters
()
{
this
->
make_handle
(
&
migraphx_program_parameters_create
);
}
...
@@ -604,9 +631,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
...
@@ -604,9 +631,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
{
{
arguments
(
migraphx_arguments
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
arguments
);
arguments
(
migraphx_arguments
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -619,27 +644,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
...
@@ -619,27 +644,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
{
const_migraphx_argument_t
pout
;
const_migraphx_argument_t
pout
;
call
(
&
migraphx_arguments_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
call
(
&
migraphx_arguments_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
struct
iterator_read
{
migraphx_arguments
*
self
;
argument
operator
()(
size_t
pidx
)
const
{
const_migraphx_argument_t
pout
;
call
(
&
migraphx_arguments_get
,
&
pout
,
self
,
pidx
);
return
{
pout
};
}
};
};
};
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
{
{
shapes
(
migraphx_shapes
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shapes
);
shapes
(
migraphx_shapes
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -652,26 +663,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
...
@@ -652,26 +663,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
{
const_migraphx_shape_t
pout
;
const_migraphx_shape_t
pout
;
call
(
&
migraphx_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
call
(
&
migraphx_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
}
struct
iterator_read
{
migraphx_shapes
*
self
;
shape
operator
()(
size_t
pidx
)
const
{
const_migraphx_shape_t
pout
;
call
(
&
migraphx_shapes_get
,
&
pout
,
self
,
pidx
);
return
{
pout
};
}
};
};
};
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
{
operation
(
migraphx_operation
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
operation
);
operation
(
migraphx_operation
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
...
@@ -689,15 +687,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
...
@@ -689,15 +687,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
{
{
instruction
(
migraphx_instruction
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instruction
);
};
};
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instructions
);
instructions
(
migraphx_instructions
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
instructions
(
migraphx_instructions
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
instructions
(
Ts
...
xs
)
instructions
(
Ts
...
xs
)
...
@@ -711,33 +706,36 @@ struct module;
...
@@ -711,33 +706,36 @@ struct module;
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
modules
);
modules
(
migraphx_modules
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
modules
(
migraphx_modules
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
modules
(
Ts
...
xs
)
modules
(
Ts
...
xs
)
{
{
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
mm
...};
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
get_handle_ptr
()
...};
this
->
make_handle
(
&
migraphx_modules_create
,
a
.
data
(),
a
.
size
());
this
->
make_handle
(
&
migraphx_modules_create
,
a
.
data
(),
a
.
size
());
}
}
};
};
struct
module
struct
module
{
{
migraphx_module_t
mm
;
MIGRAPHX_DEPRECATED
(
"Constructor without lifetime annotation is deprecated."
)
module
(
migraphx_module
*
m
)
:
mm
(
std
::
shared_ptr
<
migraphx_module
*>
(),
m
)
{}
module
(
const
migraphx_module
_t
&
m
)
:
mm
(
m
)
{}
module
(
migraphx_module
*
m
,
borrow
)
:
mm
(
std
::
shared_ptr
<
migraphx_module
*>
(),
m
)
{}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
);
}
template
<
class
T
>
module
(
migraphx_module
*
m
,
share
<
T
>
b
)
:
mm
(
b
.
alias
(
m
))
{
}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
.
get
());
}
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
)
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
)
{
{
migraphx_instruction_t
op_ins
;
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction
,
call
(
&
migraphx_module_add_instruction
,
&
op_ins
,
&
op_ins
,
mm
,
mm
.
get
()
,
op
.
get_handle_ptr
(),
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
());
args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
return
instruction
(
op_ins
,
own
{});
...
@@ -750,40 +748,72 @@ struct module
...
@@ -750,40 +748,72 @@ struct module
migraphx_instruction_t
op_ins
;
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction_with_mod_args
,
call
(
&
migraphx_module_add_instruction_with_mod_args
,
&
op_ins
,
&
op_ins
,
mm
,
mm
.
get
()
,
op
.
get_handle_ptr
(),
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
(),
args
.
get_handle_ptr
(),
module_args
.
get_handle_ptr
());
module_args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
return
instruction
(
op_ins
,
own
{});
}
}
template
<
typename
T
>
instruction
add_literal
(
const
migraphx
::
shape
&
s
,
T
*
buffer
)
{
migraphx_instruction_t
literal_ins
;
const
auto
*
buffer_ptr
=
reinterpret_cast
<
const
char
*>
(
buffer
);
call
(
&
migraphx_module_add_literal
,
&
literal_ins
,
mm
.
get
(),
s
.
get_handle_ptr
(),
buffer_ptr
);
return
instruction
(
literal_ins
,
own
{});
}
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
{
{
migraphx_instruction_t
param_ins
;
migraphx_instruction_t
param_ins
;
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
,
name
.
c_str
(),
s
.
get_handle_ptr
());
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
.
get
(),
name
.
c_str
(),
s
.
get_handle_ptr
());
return
instruction
(
param_ins
,
own
{});
return
instruction
(
param_ins
,
own
{});
}
}
instruction
add_return
(
const
migraphx
::
instructions
&
args
)
instruction
add_return
(
const
migraphx
::
instructions
&
args
)
{
{
migraphx_instruction_t
ret_ins
;
migraphx_instruction_t
ret_ins
;
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
,
args
.
get_handle_ptr
());
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
.
get
()
,
args
.
get_handle_ptr
());
return
instruction
(
ret_ins
,
own
{});
return
instruction
(
ret_ins
,
own
{});
}
}
migraphx_module_t
get_handle_ptr
()
const
{
return
mm
.
get
();
}
private:
std
::
shared_ptr
<
migraphx_module
>
mm
;
};
};
struct
context
struct
context
{
{
migraphx_context
_t
ctx
;
context
(
migraphx_context
*
p
,
borrow
)
:
ctx
(
std
::
shared_ptr
<
migraphx_context
*>
(),
p
)
{}
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
);
}
template
<
class
T
>
context
(
migraphx_context
*
p
,
share
<
T
>
b
)
:
ctx
(
b
.
alias
(
p
))
{
}
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
.
get
());
}
template
<
class
T
>
T
get_queue
()
{
void
*
out
;
call
(
&
migraphx_context_get_queue
,
&
out
,
ctx
.
get
());
// TODO: check type here
return
reinterpret_cast
<
T
>
(
out
);
}
private:
std
::
shared_ptr
<
migraphx_context
>
ctx
;
};
};
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
{
{
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
compile_options
(
migraphx_compile_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
());
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
compile_options
);
/// For targets with offloaded memory(such as the gpu), this will insert
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// instructions during compilation to copy the input parameters to the
...
@@ -807,9 +837,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -807,9 +837,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
{
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
(
migraphx_program
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program
);
program
(
migraphx_program
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Compile the program for a specific target to be ran on
/// Compile the program for a specific target to be ran on
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
...
@@ -872,21 +900,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -872,21 +900,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
{
migraphx_module_t
p_modu
;
migraphx_module_t
p_modu
;
call
(
&
migraphx_program_get_main_module
,
&
p_modu
,
this
->
get_handle_ptr
());
call
(
&
migraphx_program_get_main_module
,
&
p_modu
,
this
->
get_handle_ptr
());
return
module
{
p_modu
};
return
module
{
p_modu
,
this
->
share_handle
()
};
}
}
context
experimental_get_context
()
context
experimental_get_context
()
{
{
migraphx_context_t
ctx
;
migraphx_context_t
ctx
;
call
(
&
migraphx_program_experimental_get_context
,
&
ctx
,
this
->
get_handle_ptr
());
call
(
&
migraphx_program_experimental_get_context
,
&
ctx
,
this
->
get_handle_ptr
());
return
context
{
ctx
};
return
context
{
ctx
,
this
->
share_handle
()
};
}
}
module
create_module
(
const
std
::
string
&
name
)
module
create_module
(
const
std
::
string
&
name
)
{
{
migraphx_module_t
p_modu
;
migraphx_module_t
p_modu
;
call
(
&
migraphx_program_create_module
,
&
p_modu
,
this
->
get_handle_ptr
(),
name
.
data
());
call
(
&
migraphx_program_create_module
,
&
p_modu
,
this
->
get_handle_ptr
(),
name
.
data
());
return
module
{
p_modu
};
return
module
{
p_modu
,
this
->
share_handle
()
};
}
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
...
@@ -895,10 +923,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -895,10 +923,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options
// options for migraphx file format options
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
file_options
);
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
file_options
(
migraphx_file_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
());
}
// set file format
// set file format
void
set_file_format
(
const
char
*
format
)
void
set_file_format
(
const
char
*
format
)
{
{
...
@@ -938,7 +965,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
...
@@ -938,7 +965,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
{
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
onnx_options
(
migraphx_onnx_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
onnx_options
);
/// Make onnx parser treat an inputs with a certain dimensions
/// Make onnx parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
@@ -1020,7 +1047,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
...
@@ -1020,7 +1047,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
{
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
tf_options
(
migraphx_tf_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
tf_options
);
/// Make tf parser treat an inputs with a certain dimensions
/// Make tf parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
@@ -1073,7 +1100,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
...
@@ -1073,7 +1100,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
{
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
quantize_op_names
(
migraphx_quantize_op_names
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_op_names
);
void
add
(
const
std
::
string
&
name
)
void
add
(
const
std
::
string
&
name
)
{
{
...
@@ -1098,12 +1125,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
...
@@ -1098,12 +1125,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
{
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
quantize_int8_options
(
migraphx_quantize_int8_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_int8_options
);
quantize_int8_options
(
migraphx_quantize_int8_options
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
/// Add an operator that should be quantized
/// Add an operator that should be quantized
void
add_op_name
(
const
std
::
string
&
name
)
void
add_op_name
(
const
std
::
string
&
name
)
...
...
src/api/migraphx.py
View file @
2f268bc2
...
@@ -212,6 +212,9 @@ def module(h):
...
@@ -212,6 +212,9 @@ def module(h):
module_refs
=
'std::vector<migraphx::module*>'
),
module_refs
=
'std::vector<migraphx::module*>'
),
fname
=
'add_instruction'
,
fname
=
'add_instruction'
,
returns
=
'migraphx::instruction_ref'
)
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_literal'
,
api
.
params
(
shape
=
'const migraphx::shape&'
,
buffer
=
'const char*'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_parameter'
,
h
.
method
(
'add_parameter'
,
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
returns
=
'migraphx::instruction_ref'
)
returns
=
'migraphx::instruction_ref'
)
...
@@ -403,6 +406,7 @@ api.add_function('migraphx_quantize_int8',
...
@@ -403,6 +406,7 @@ api.add_function('migraphx_quantize_int8',
@
auto_handle
(
ref
=
True
)
@
auto_handle
(
ref
=
True
)
def
context
(
h
):
def
context
(
h
):
h
.
method
(
'finish'
,
const
=
True
)
h
.
method
(
'finish'
,
const
=
True
)
h
.
method
(
'get_queue'
,
returns
=
'void*'
,
fname
=
'get_queue().unsafe_get'
)
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
...
...
src/argument.cpp
View file @
2f268bc2
...
@@ -29,7 +29,6 @@ void argument::assign_buffer(std::function<char*()> d)
...
@@ -29,7 +29,6 @@ void argument::assign_buffer(std::function<char*()> d)
// Collect all shapes
// Collect all shapes
std
::
unordered_map
<
std
::
size_t
,
shape
>
shapes
;
std
::
unordered_map
<
std
::
size_t
,
shape
>
shapes
;
{
{
// cppcheck-suppress variableScope
std
::
size_t
i
=
0
;
std
::
size_t
i
=
0
;
fix
([
&
](
auto
self
,
auto
ss
)
{
fix
([
&
](
auto
self
,
auto
ss
)
{
if
(
ss
.
sub_shapes
().
empty
())
if
(
ss
.
sub_shapes
().
empty
())
...
@@ -60,7 +59,6 @@ void argument::assign_buffer(std::function<char*()> d)
...
@@ -60,7 +59,6 @@ void argument::assign_buffer(std::function<char*()> d)
}
}
assert
(
offset
==
s
.
bytes
());
assert
(
offset
==
s
.
bytes
());
// cppcheck-suppress variableScope
std
::
size_t
i
=
0
;
std
::
size_t
i
=
0
;
m_data
=
fix
<
data_t
>
([
&
](
auto
self
,
auto
ss
)
{
m_data
=
fix
<
data_t
>
([
&
](
auto
self
,
auto
ss
)
{
data_t
result
;
data_t
result
;
...
...
src/auto_contiguous.cpp
View file @
2f268bc2
...
@@ -8,10 +8,10 @@
...
@@ -8,10 +8,10 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
auto_contiguous
::
apply
(
module
&
p
)
const
void
auto_contiguous
::
apply
(
module
&
m
)
const
{
{
std
::
string
key
=
"require_std_shape"
;
std
::
string
key
=
"require_std_shape"
;
for
(
auto
ins
:
reverse_iterator_for
(
p
))
for
(
auto
ins
:
reverse_iterator_for
(
m
))
{
{
auto
&&
attr
=
ins
->
get_operator
().
attributes
();
auto
&&
attr
=
ins
->
get_operator
().
attributes
();
if
((
attr
.
get
(
key
,
false
)))
if
((
attr
.
get
(
key
,
false
)))
...
@@ -23,18 +23,18 @@ void auto_contiguous::apply(module& p) const
...
@@ -23,18 +23,18 @@ void auto_contiguous::apply(module& p) const
{
{
return
in
;
return
in
;
}
}
return
p
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
return
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
});
});
if
(
new_args
!=
args
)
if
(
new_args
!=
args
)
{
{
p
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
m
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
}
}
}
}
}
}
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// for last instruction that is NOT a return
// for last instruction that is NOT a return
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
...
@@ -42,8 +42,8 @@ void auto_contiguous::apply(module& p) const
...
@@ -42,8 +42,8 @@ void auto_contiguous::apply(module& p) const
shape
s
=
ins
->
get_shape
();
shape
s
=
ins
->
get_shape
();
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
{
{
auto
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
auto
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
p
.
replace_instruction
(
ins
,
c
);
m
.
replace_instruction
(
ins
,
c
);
}
}
}
}
}
}
...
...
src/compile_src.cpp
View file @
2f268bc2
...
@@ -28,7 +28,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
...
@@ -28,7 +28,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
{
{
params
+=
" "
+
src
.
path
.
filename
().
string
();
params
+=
" "
+
src
.
path
.
filename
().
string
();
if
(
out
.
empty
())
if
(
out
.
empty
())
out
=
src
.
path
.
stem
().
string
()
+
".o"
;
out
=
src
.
path
.
stem
().
string
()
+
out_ext
;
}
}
}
}
...
...
src/dead_code_elimination.cpp
View file @
2f268bc2
...
@@ -9,26 +9,6 @@
...
@@ -9,26 +9,6 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Range
,
class
Iterator
>
std
::
ptrdiff_t
bidistance
(
const
Range
&
r
,
Iterator
start
,
Iterator
last
)
{
auto
start_forward
=
start
;
auto
start_backwards
=
start
;
std
::
size_t
n
=
0
;
while
(
start_forward
!=
last
and
start_backwards
!=
last
)
{
n
++
;
if
(
start_forward
!=
r
.
end
())
start_forward
++
;
if
(
start_backwards
!=
r
.
begin
())
start_backwards
--
;
}
if
(
start_forward
==
last
)
return
n
;
else
return
-
n
;
}
void
dead_code_elimination
::
apply
(
program
&
p
)
const
{
p
.
remove_unused_modules
();
}
void
dead_code_elimination
::
apply
(
program
&
p
)
const
{
p
.
remove_unused_modules
();
}
void
dead_code_elimination
::
apply
(
module
&
m
)
const
void
dead_code_elimination
::
apply
(
module
&
m
)
const
...
@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
...
@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
if
(
i
->
get_shape
().
elements
()
==
0
and
i
->
name
().
front
()
!=
'@'
and
if
(
i
->
get_shape
().
elements
()
==
0
and
i
->
name
().
front
()
!=
'@'
and
i
->
name
()
!=
"undefined"
and
i
->
name
()
!=
"identity"
)
i
->
name
()
!=
"undefined"
and
i
->
name
()
!=
"identity"
)
continue
;
continue
;
assert
(
bidistance
(
m
,
i
,
last
)
>
0
);
assert
(
std
::
distance
(
m
.
begin
(),
i
)
<=
std
::
distance
(
m
.
begin
(),
last
));
std
::
unordered_set
<
instruction_ref
>
visited
;
fix
([
&
](
auto
self
,
auto
leaf
)
{
fix
([
&
](
auto
self
,
auto
leaf
)
{
if
(
not
m
.
has_instruction
(
leaf
))
if
(
not
m
.
has_instruction
(
leaf
))
return
;
return
;
if
(
leaf
->
outputs
().
empty
())
if
(
leaf
->
outputs
().
empty
())
{
{
// Dont visit inputs twice
if
(
not
visited
.
insert
(
leaf
).
second
)
return
;
std
::
unordered_set
<
instruction_ref
>
args
(
leaf
->
inputs
().
begin
(),
std
::
unordered_set
<
instruction_ref
>
args
(
leaf
->
inputs
().
begin
(),
leaf
->
inputs
().
end
());
leaf
->
inputs
().
end
());
leaf
->
clear_arguments
();
leaf
->
clear_arguments
();
assert
(
bi
distance
(
m
,
last
,
leaf
)
<
0
);
assert
(
std
::
distance
(
m
.
begin
(),
leaf
)
<
std
::
distance
(
m
.
begin
(),
last
)
);
assert
(
leaf
!=
ins
);
assert
(
leaf
!=
ins
);
if
(
leaf
->
name
()
!=
"@param"
)
if
(
leaf
->
name
()
!=
"@param"
)
m
.
move_instruction
(
leaf
,
m
.
end
());
m
.
move_instruction
(
leaf
,
m
.
end
());
...
...
src/driver/marker_roctx.cpp
View file @
2f268bc2
...
@@ -17,7 +17,7 @@ class marker_roctx
...
@@ -17,7 +17,7 @@ class marker_roctx
std
::
function
<
int
(
const
char
*
)
>
sym_roctx_range_push
;
std
::
function
<
int
(
const
char
*
)
>
sym_roctx_range_push
;
std
::
function
<
int
()
>
sym_roctx_range_pop
;
std
::
function
<
int
()
>
sym_roctx_range_pop
;
uint64_t
range_id
;
uint64_t
range_id
=
0
;
public:
public:
marker_roctx
()
marker_roctx
()
...
...
src/eliminate_allocation.cpp
View file @
2f268bc2
...
@@ -13,13 +13,13 @@
...
@@ -13,13 +13,13 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_allocation
::
apply
(
module
&
p
)
const
void
eliminate_allocation
::
apply
(
module
&
m
)
const
{
{
assert
(
alignment
>
0
);
assert
(
alignment
>
0
);
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
std
::
size_t
>>
allocs
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
std
::
size_t
>>
allocs
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()
!=
allocation_op
)
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
continue
;
...
@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
...
@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
}
}
if
(
n
>
0
)
if
(
n
>
0
)
{
{
auto
mem
=
p
.
add_parameter
(
"memory"
,
shape
{
shape
::
int8_type
,
{
n
}});
auto
mem
=
m
.
add_parameter
(
"memory"
,
shape
{
shape
::
int8_type
,
{
n
}});
for
(
auto
&&
pp
:
allocs
)
for
(
auto
&&
pp
:
allocs
)
{
{
auto
ins
=
pp
.
first
;
auto
ins
=
pp
.
first
;
auto
s
=
ins
->
get_shape
();
auto
s
=
ins
->
get_shape
();
auto
offset
=
pp
.
second
;
auto
offset
=
pp
.
second
;
p
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
}
}
}
}
...
...
src/eliminate_common_subexpression.cpp
View file @
2f268bc2
...
@@ -11,7 +11,7 @@ namespace migraphx {
...
@@ -11,7 +11,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Range
>
template
<
class
Range
>
void
cse_range
(
module
&
p
,
Range
&&
r
)
void
cse_range
(
module
&
m
,
Range
&&
r
)
{
{
std
::
unordered_multimap
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_multimap
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_set
<
instruction_ref
>
processed_ins
;
std
::
unordered_set
<
instruction_ref
>
processed_ins
;
...
@@ -30,24 +30,24 @@ void cse_range(module& p, Range&& r)
...
@@ -30,24 +30,24 @@ void cse_range(module& p, Range&& r)
continue
;
continue
;
if
(
*
eq
!=
*
ins
)
if
(
*
eq
!=
*
ins
)
continue
;
continue
;
p
.
replace_instruction
(
ins
,
eq
);
m
.
replace_instruction
(
ins
,
eq
);
processed_ins
.
emplace
(
ins
);
processed_ins
.
emplace
(
ins
);
std
::
vector
<
instruction_ref
>
outputs
;
std
::
vector
<
instruction_ref
>
outputs
;
std
::
copy_if
(
eq
->
outputs
().
begin
(),
std
::
copy_if
(
eq
->
outputs
().
begin
(),
eq
->
outputs
().
end
(),
eq
->
outputs
().
end
(),
std
::
back_inserter
(
outputs
),
std
::
back_inserter
(
outputs
),
[
&
](
auto
x
)
{
return
p
.
has_instruction
(
x
);
});
[
&
](
auto
x
)
{
return
m
.
has_instruction
(
x
);
});
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
});
});
cse_range
(
p
,
outputs
);
cse_range
(
m
,
outputs
);
}
}
instructions
.
emplace
(
ins
->
name
(),
ins
);
instructions
.
emplace
(
ins
->
name
(),
ins
);
}
}
}
}
void
eliminate_common_subexpression
::
apply
(
module
&
p
)
const
{
cse_range
(
p
,
iterator_for
(
p
));
}
void
eliminate_common_subexpression
::
apply
(
module
&
m
)
const
{
cse_range
(
m
,
iterator_for
(
m
));
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/eliminate_concat.cpp
View file @
2f268bc2
...
@@ -13,9 +13,9 @@
...
@@ -13,9 +13,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_concat
::
apply
(
module
&
p
)
const
void
eliminate_concat
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// Look for the concat operator
// Look for the concat operator
if
(
ins
->
name
()
!=
concat_opt
.
name
())
if
(
ins
->
name
()
!=
concat_opt
.
name
())
...
@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
...
@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
std
::
sort
(
sorted_allocations
.
begin
(),
std
::
sort
(
sorted_allocations
.
begin
(),
sorted_allocations
.
end
(),
sorted_allocations
.
end
(),
[
&
](
instruction_ref
x
,
instruction_ref
y
)
{
[
&
](
instruction_ref
x
,
instruction_ref
y
)
{
return
std
::
distance
(
p
.
begin
(),
x
)
<
std
::
distance
(
p
.
begin
(),
y
);
return
std
::
distance
(
m
.
begin
(),
x
)
<
std
::
distance
(
m
.
begin
(),
y
);
});
});
// Move "super" allocation to the front
// Move "super" allocation to the front
auto
first
=
sorted_allocations
.
front
();
auto
first
=
sorted_allocations
.
front
();
auto
super
=
p
.
move_instruction
(
last
,
first
);
auto
super
=
m
.
move_instruction
(
last
,
first
);
// Replace each allocation with a load
// Replace each allocation with a load
std
::
size_t
offset
=
0
;
std
::
size_t
offset
=
0
;
for
(
auto
alloc
:
allocations
)
for
(
auto
alloc
:
allocations
)
{
{
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
p
.
replace_instruction
(
alloc
,
op
,
{
super
});
m
.
replace_instruction
(
alloc
,
op
,
{
super
});
offset
+=
alloc
->
get_shape
().
bytes
();
offset
+=
alloc
->
get_shape
().
bytes
();
}
}
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
p
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"identity"
),
args
);
m
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"identity"
),
args
);
}
}
}
}
}
}
...
...
src/eliminate_contiguous.cpp
View file @
2f268bc2
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp>
#include <utility>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -69,9 +70,11 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -69,9 +70,11 @@ static bool try_compute_shape(instruction_ref ins,
return
try_compute_shape
(
ins
,
inputs
,
mods
);
return
try_compute_shape
(
ins
,
inputs
,
mods
);
}
}
void
eliminate_contiguous
::
apply
(
module
&
p
)
const
void
eliminate_contiguous
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
std
::
vector
<
instruction_ref
>
const_instruction
;
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// return instruction should have inputs with standard shape
// return instruction should have inputs with standard shape
if
(
ins
->
name
()
==
"@return"
)
if
(
ins
->
name
()
==
"@return"
)
...
@@ -81,6 +84,7 @@ void eliminate_contiguous::apply(module& p) const
...
@@ -81,6 +84,7 @@ void eliminate_contiguous::apply(module& p) const
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
auto
new_args
=
args
;
auto
mod_args
=
ins
->
module_inputs
();
auto
mod_args
=
ins
->
module_inputs
();
for
(
auto
arg
:
ins
->
inputs
())
for
(
auto
arg
:
ins
->
inputs
())
{
{
if
(
arg
->
name
()
==
op_name
)
if
(
arg
->
name
()
==
op_name
)
...
@@ -93,15 +97,25 @@ void eliminate_contiguous::apply(module& p) const
...
@@ -93,15 +97,25 @@ void eliminate_contiguous::apply(module& p) const
}
}
else
if
(
prev
->
can_eval
())
else
if
(
prev
->
can_eval
())
{
{
auto
c
=
op
::
contiguous
{};
const_instruction
.
push_back
(
arg
);
auto
r
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
p
.
replace_instruction
(
arg
,
l
);
}
}
}
}
}
}
}
}
// Perform evaluations in parallel
std
::
vector
<
argument
>
literals
(
const_instruction
.
size
());
par_for
(
const_instruction
.
size
(),
1
,
[
&
](
const
auto
i
)
{
auto
c
=
op
::
contiguous
{};
auto
prev
=
const_instruction
[
i
]
->
inputs
().
front
();
literals
[
i
]
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
});
for
(
size_t
i
=
0
;
i
<
const_instruction
.
size
();
i
++
)
{
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
m
.
replace_instruction
(
const_instruction
[
i
],
l
);
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/eliminate_identity.cpp
View file @
2f268bc2
...
@@ -8,21 +8,21 @@
...
@@ -8,21 +8,21 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_identity
::
apply
(
module
&
p
)
const
void
eliminate_identity
::
apply
(
module
&
m
)
const
{
{
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// Skip the first instruction, since we always process the previous
// Skip the first instruction, since we always process the previous
// instruction
// instruction
if
(
ins
==
p
.
begin
())
if
(
ins
==
m
.
begin
())
continue
;
continue
;
const
auto
i
=
std
::
prev
(
ins
);
const
auto
i
=
std
::
prev
(
ins
);
if
(
i
->
name
()
==
"identity"
)
if
(
i
->
name
()
==
"identity"
)
{
{
p
.
replace_instruction
(
i
,
i
->
inputs
().
front
());
m
.
replace_instruction
(
i
,
i
->
inputs
().
front
());
p
.
move_instruction
(
i
,
p
.
end
());
m
.
move_instruction
(
i
,
m
.
end
());
}
}
if
(
ins
==
last
)
if
(
ins
==
last
)
{
{
...
@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
...
@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
const
instruction_ref
&
identity_input
=
ins
->
inputs
().
front
();
const
instruction_ref
&
identity_input
=
ins
->
inputs
().
front
();
if
(
identity_input
->
outputs
().
size
()
==
1
)
if
(
identity_input
->
outputs
().
size
()
==
1
)
{
{
p
.
move_instruction
(
identity_input
,
i
);
m
.
move_instruction
(
identity_input
,
i
);
// since this is the last instruction, removing it only
// since this is the last instruction, removing it only
// requires changing "last" and calling remove below
// requires changing "last" and calling remove below
last
=
std
::
prev
(
last
);
last
=
std
::
prev
(
last
);
...
@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
...
@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
break
;
break
;
}
}
}
}
p
.
remove_instructions
(
std
::
next
(
last
),
p
.
end
());
m
.
remove_instructions
(
std
::
next
(
last
),
m
.
end
());
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/adjust_allocation.hpp
View file @
2f268bc2
...
@@ -13,7 +13,7 @@ struct adjust_allocation
...
@@ -13,7 +13,7 @@ struct adjust_allocation
{
{
allocation_model
model
;
allocation_model
model
;
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/analyze_streams.hpp
View file @
2f268bc2
...
@@ -16,7 +16,7 @@ struct stream_race
...
@@ -16,7 +16,7 @@ struct stream_race
instruction_ref
before
;
instruction_ref
before
;
};
};
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
p
,
const
stream_model
&
m
);
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
m
,
const
stream_model
&
strm
m
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/auto_contiguous.hpp
View file @
2f268bc2
...
@@ -13,7 +13,7 @@ struct module;
...
@@ -13,7 +13,7 @@ struct module;
struct
auto_contiguous
struct
auto_contiguous
{
{
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/check_context.hpp
View file @
2f268bc2
...
@@ -33,7 +33,7 @@ struct check_context
...
@@ -33,7 +33,7 @@ struct check_context
};
};
std
::
string
name
()
const
{
return
"check_context"
;
}
std
::
string
name
()
const
{
return
"check_context"
;
}
void
apply
(
module
&
p
)
const
{
p
.
insert_instruction
(
p
.
begin
(),
op
{});
}
void
apply
(
module
&
m
)
const
{
m
.
insert_instruction
(
m
.
begin
(),
op
{});
}
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/compile_src.hpp
View file @
2f268bc2
...
@@ -24,6 +24,7 @@ struct src_compiler
...
@@ -24,6 +24,7 @@ struct src_compiler
std
::
string
flags
=
""
;
std
::
string
flags
=
""
;
std
::
string
output
=
""
;
std
::
string
output
=
""
;
std
::
string
launcher
=
""
;
std
::
string
launcher
=
""
;
std
::
string
out_ext
=
".o"
;
std
::
function
<
fs
::
path
(
fs
::
path
)
>
process
=
nullptr
;
std
::
function
<
fs
::
path
(
fs
::
path
)
>
process
=
nullptr
;
std
::
vector
<
char
>
compile
(
const
std
::
vector
<
src_file
>&
srcs
)
const
;
std
::
vector
<
char
>
compile
(
const
std
::
vector
<
src_file
>&
srcs
)
const
;
};
};
...
...
src/include/migraphx/eliminate_allocation.hpp
View file @
2f268bc2
...
@@ -19,7 +19,7 @@ struct eliminate_allocation
...
@@ -19,7 +19,7 @@ struct eliminate_allocation
std
::
string
allocation_op
{};
std
::
string
allocation_op
{};
std
::
size_t
alignment
=
32
;
std
::
size_t
alignment
=
32
;
std
::
string
name
()
const
{
return
"eliminate_allocation"
;
}
std
::
string
name
()
const
{
return
"eliminate_allocation"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_common_subexpression.hpp
View file @
2f268bc2
...
@@ -16,7 +16,7 @@ struct module;
...
@@ -16,7 +16,7 @@ struct module;
struct
eliminate_common_subexpression
struct
eliminate_common_subexpression
{
{
std
::
string
name
()
const
{
return
"eliminate_common_subexpression"
;
}
std
::
string
name
()
const
{
return
"eliminate_common_subexpression"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_concat.hpp
View file @
2f268bc2
...
@@ -18,7 +18,7 @@ struct eliminate_concat
...
@@ -18,7 +18,7 @@ struct eliminate_concat
{
{
concat_optimization
concat_opt
;
concat_optimization
concat_opt
;
std
::
string
name
()
const
{
return
"eliminate_concat"
;
}
std
::
string
name
()
const
{
return
"eliminate_concat"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
Prev
1
2
3
4
5
6
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment