Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7f97b8ef
Unverified
Commit
7f97b8ef
authored
Oct 07, 2022
by
Ted Themistokleous
Committed by
GitHub
Oct 07, 2022
Browse files
Merge branch 'simplify_1_mul_div_ops' into divide_by_zero_check
parents
2ba401f0
d1fed367
Changes
448
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
177 additions
and
139 deletions
+177
-139
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+1
-1
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+8
-6
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+2
-0
src/include/migraphx/pad_calc.hpp
src/include/migraphx/pad_calc.hpp
+23
-25
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+2
-2
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+4
-3
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+4
-4
src/include/migraphx/reflect.hpp
src/include/migraphx/reflect.hpp
+1
-1
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+1
-1
src/include/migraphx/rewrite_gelu.hpp
src/include/migraphx/rewrite_gelu.hpp
+14
-10
src/include/migraphx/schedule_model.hpp
src/include/migraphx/schedule_model.hpp
+2
-2
src/include/migraphx/sqlite.hpp
src/include/migraphx/sqlite.hpp
+19
-12
src/include/migraphx/stream_model.hpp
src/include/migraphx/stream_model.hpp
+2
-2
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+17
-9
src/include/migraphx/stringutils.hpp
src/include/migraphx/stringutils.hpp
+6
-6
src/include/migraphx/supported_segments.hpp
src/include/migraphx/supported_segments.hpp
+11
-9
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+35
-32
src/include/migraphx/target_assignments.hpp
src/include/migraphx/target_assignments.hpp
+14
-3
src/include/migraphx/tensor_view.hpp
src/include/migraphx/tensor_view.hpp
+10
-10
src/include/migraphx/tune_axis.hpp
src/include/migraphx/tune_axis.hpp
+1
-1
No files found.
src/include/migraphx/op/transpose.hpp
View file @
7f97b8ef
...
...
@@ -59,7 +59,7 @@ struct transpose
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
if
(
not
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
{
MIGRAPHX_THROW
(
"TRANSPOSE: Invalid permutation"
);
}
...
...
src/include/migraphx/operation.hpp
View file @
7f97b8ef
...
...
@@ -68,8 +68,10 @@ struct operation
*
* @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each
* `shape` of the `argument`.
* @param output Equivalent to running `compute_shape` with each `shape` of the `argument`.
* For a fixed shape, the returned argument will have the same shape as `output`.
* For a dynamic shape, the returned `argument` will be a fixed shape within the bounds
* set in the dynamic shape `output`.
* @param input This is the `argument` result from the previous instruction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape.
...
...
@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
->
decltype
(
x
.
normalize_compute_shape
(
inputs
))
{
dependent_type
<
operation
,
T
>
y
=
x
;
normalize_attributes
(
y
,
inputs
[
0
].
lens
());
normalize_attributes
(
y
,
inputs
[
0
].
max_
lens
());
return
any_cast
<
T
>
(
y
).
normalize_compute_shape
(
inputs
);
}
...
...
@@ -1064,7 +1066,7 @@ struct operation
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -1235,7 +1237,7 @@ struct operation
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
@@ -1274,7 +1276,7 @@ inline const ValueType& any_cast(const operation& x)
}
#endif
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
not
(
x
==
y
);
}
inline
value
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
...
...
src/include/migraphx/operators.hpp
View file @
7f97b8ef
...
...
@@ -57,6 +57,7 @@
#include <migraphx/op/exp.hpp>
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/fmod.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
...
...
@@ -79,6 +80,7 @@
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/max.hpp>
#include <migraphx/op/min.hpp>
#include <migraphx/op/mod.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/neg.hpp>
...
...
src/include/migraphx/pad_calc.hpp
View file @
7f97b8ef
...
...
@@ -24,38 +24,36 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#include <
utility
>
#include <
migraphx/config.hpp
>
#include <cstdint>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
void
calculate_padding
(
int64_t
idx
,
void
calculate_padding
(
int64_t
idx
,
std
::
vector
<
int64_t
>&
pads
,
int64_t
input_dim
,
int64_t
stride
,
int64_t
dilation
,
int64_t
weight_dim
,
bool
is_same_upper
=
true
)
{
int64_t
output_dim
=
(
input_dim
+
stride
-
1
)
/
stride
;
// round up result
int64_t
new_weight_dim
=
weight_dim
+
(
weight_dim
-
1
)
*
(
dilation
-
1
);
int64_t
pad
=
std
::
max
(
static_cast
<
int64_t
>
(
0
),
(
output_dim
-
1
)
*
stride
+
new_weight_dim
-
input_dim
);
auto
pad_ndims
=
pads
.
size
()
/
2
;
bool
is_same_upper
=
true
);
if
(
is_same_upper
)
{
pads
[
idx
]
=
pad
/
2
;
pads
[
idx
+
pad_ndims
]
=
pad
-
pad
/
2
;
}
else
{
pads
[
idx
+
pad_ndims
]
=
pad
/
2
;
pads
[
idx
]
=
pad
-
pad
/
2
;
}
}
/*!
* Calculate the padding for auto_padding. Used for dynamic shapes
* where the padding calculation must be done at evaluation time.
* \param tensor_lens input tensor image shape
* \param k_lens weights kernel shape
* \param strides strides for the kernel
* \param dilations dilations for the kernel
* \param use_upper put odd padding on upper or lower side
* \return padding in the form of {x0_begin, x1_begin, ... x0_end , x1_end, ...}
*/
std
::
vector
<
std
::
size_t
>
calc_dyn_auto_pad
(
std
::
vector
<
std
::
size_t
>
tensor_lens
,
std
::
vector
<
std
::
size_t
>
k_lens
,
std
::
vector
<
std
::
size_t
>
strides
,
std
::
vector
<
std
::
size_t
>
dilations
,
bool
use_upper
=
true
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/pass.hpp
View file @
7f97b8ef
...
...
@@ -238,7 +238,7 @@ struct pass
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -292,7 +292,7 @@ struct pass
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/program.hpp
View file @
7f97b8ef
...
...
@@ -37,6 +37,7 @@
#include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <migraphx/execution_environment.hpp>
#include <algorithm>
#include <iostream>
...
...
@@ -76,8 +77,8 @@ struct program
std
::
unordered_map
<
std
::
string
,
shape
>
get_parameter_shapes
()
const
;
std
::
vector
<
argument
>
eval
(
parameter_map
params
)
const
;
std
::
vector
<
argument
>
eval
(
parameter_map
params
,
execution_environment
exec_env
=
execution_environment
{})
const
;
std
::
size_t
size
()
const
;
std
::
vector
<
shape
>
get_output_shapes
()
const
;
...
...
@@ -124,7 +125,7 @@ struct program
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
not
(
x
==
y
);
}
// module related api
module
*
create_module
(
const
std
::
string
&
name
);
...
...
src/include/migraphx/raw_data.hpp
View file @
7f97b8ef
...
...
@@ -147,7 +147,7 @@ struct raw_data : raw_data_base
template
<
class
T
>
bool
matches
()
const
{
return
is_data_ptr
<
T
>
{}
||
return
is_data_ptr
<
T
>
{}
or
self
->
get_shape
().
type
()
==
migraphx
::
shape
::
get_type
<
get_data_type
<
T
>>
{};
}
...
...
@@ -232,7 +232,7 @@ auto visit_all(T&& x, Ts&&... xs)
{
auto
&&
s
=
x
.
get_shape
();
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
...
vs
)
{
detail
::
visit_all_pack
(
s
,
vs
...)(
x
,
xs
...);
};
}
...
...
@@ -241,7 +241,7 @@ template <class T>
auto
visit_all
(
const
std
::
vector
<
T
>&
x
)
{
auto
&&
s
=
x
.
front
().
get_shape
();
if
(
!
std
::
all_of
(
if
(
not
std
::
all_of
(
x
.
begin
(),
x
.
end
(),
[
&
](
const
T
&
y
)
{
return
y
.
get_shape
().
type
()
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
v
)
{
...
...
@@ -281,7 +281,7 @@ template <class T,
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/reflect.hpp
View file @
7f97b8ef
...
...
@@ -129,7 +129,7 @@ template <class T>
struct
reflect_equality
{
friend
bool
operator
==
(
const
T
&
x
,
const
T
&
y
)
{
return
reflect_tie
(
x
)
==
reflect_tie
(
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
not
(
x
==
y
);
}
};
template
<
class
T
>
...
...
src/include/migraphx/requires.hpp
View file @
7f97b8ef
...
...
@@ -31,7 +31,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
bool
...
Bs
>
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
||
true
)...
>>
// NOLINT
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
or
true
)...
>>
// NOLINT
{
};
...
...
src/
targets/gpu/
include/migraphx/
gpu/device/convert
.hpp
→
src/include/migraphx/
rewrite_gelu
.hpp
View file @
7f97b8ef
...
...
@@ -21,23 +21,27 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_CONVERT_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_CONVERT_HPP
#include <migraphx/argument.hpp>
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
void
convert
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
);
struct
module
;
/**
* Rewrite gelu standard formula as the sigmoid approximation formula
*/
struct
rewrite_gelu
{
std
::
string
name
()
const
{
return
"rewrite_gelu"
;
}
void
apply
(
module
&
m
)
const
;
};
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/schedule_model.hpp
View file @
7f97b8ef
...
...
@@ -208,7 +208,7 @@ struct schedule_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -274,7 +274,7 @@ struct schedule_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/
targets/gpu/
include/migraphx/
gpu/device/equal
.hpp
→
src/include/migraphx/
sqlite
.hpp
View file @
7f97b8ef
...
...
@@ -21,24 +21,31 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_EQUAL_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_EQUAL_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
#include <migraphx/filesystem.hpp>
#include <memory>
#include <unordered_map>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
void
equal
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
);
struct
sqlite_impl
;
struct
sqlite
{
sqlite
()
=
default
;
static
sqlite
read
(
const
fs
::
path
&
p
);
static
sqlite
write
(
const
fs
::
path
&
p
);
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
string
>>
execute
(
const
std
::
string
&
s
);
private:
std
::
shared_ptr
<
sqlite_impl
>
impl
;
};
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
#endif // MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
src/include/migraphx/stream_model.hpp
View file @
7f97b8ef
...
...
@@ -216,7 +216,7 @@ struct stream_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -288,7 +288,7 @@ struct stream_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/streamutils.hpp
View file @
7f97b8ef
...
...
@@ -28,6 +28,7 @@
#include <algorithm>
#include <migraphx/rank.hpp>
#include <migraphx/config.hpp>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -41,7 +42,7 @@ struct stream_range_container
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
stream_range_container
&
sr
)
{
assert
(
sr
.
r
!=
nullptr
);
if
(
!
sr
.
r
->
empty
())
if
(
not
sr
.
r
->
empty
())
{
os
<<
sr
.
r
->
front
();
std
::
for_each
(
...
...
@@ -59,28 +60,35 @@ inline stream_range_container<Range> stream_range(const Range& r)
namespace
detail
{
inline
void
stream_write_value_impl
(
rank
<
2
>
,
std
::
ostream
&
os
,
const
std
::
string
&
x
)
{
os
<<
x
;
}
template
<
class
T
>
auto
stream_write_value_impl
(
rank
<
1
>
,
std
::
ostream
&
os
,
const
T
&
x
)
->
decltype
(
os
<<
x
,
void
())
{
os
<<
x
;
}
template
<
class
Range
>
auto
stream_write_value_impl
(
rank
<
1
>
,
std
::
ostream
&
os
,
const
Range
&
r
)
->
decltype
(
r
.
begin
(),
r
.
end
(),
void
())
template
<
class
T
>
void
stream_write_value_impl
(
rank
<
1
>
,
std
::
ostream
&
os
,
const
std
::
vector
<
T
>&
r
)
{
os
<<
"{"
;
os
<<
stream_range
(
r
);
os
<<
"}"
;
}
template
<
class
T
>
void
stream_write_value_impl
(
rank
<
0
>
,
std
::
ostream
&
os
,
const
T
&
x
)
template
<
class
Range
>
auto
stream_write_value_impl
(
rank
<
0
>
,
std
::
ostream
&
os
,
const
Range
&
r
)
->
decltype
(
r
.
begin
(),
r
.
end
(),
void
())
{
os
<<
x
;
os
<<
"{"
;
os
<<
stream_range
(
r
);
os
<<
"}"
;
}
}
// namespace detail
template
<
class
T
>
void
stream_write_value
(
std
::
ostream
&
os
,
const
T
&
x
)
{
detail
::
stream_write_value_impl
(
rank
<
2
>
{},
os
,
x
);
detail
::
stream_write_value_impl
(
rank
<
1
>
{},
os
,
x
);
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/stringutils.hpp
View file @
7f97b8ef
...
...
@@ -174,27 +174,27 @@ inline std::string interpolate_string(const std::string& input,
}
template
<
class
Iterator
>
inline
std
::
string
to_string_range
(
Iterator
start
,
Iterator
last
)
inline
std
::
string
to_string_range
(
Iterator
start
,
Iterator
last
,
const
char
*
delim
=
", "
)
{
std
::
stringstream
ss
;
if
(
start
!=
last
)
{
ss
<<
*
start
;
std
::
for_each
(
std
::
next
(
start
),
last
,
[
&
](
auto
&&
x
)
{
ss
<<
", "
<<
x
;
});
std
::
for_each
(
std
::
next
(
start
),
last
,
[
&
](
auto
&&
x
)
{
ss
<<
delim
<<
x
;
});
}
return
ss
.
str
();
}
template
<
class
Range
>
inline
std
::
string
to_string_range
(
const
Range
&
r
)
inline
std
::
string
to_string_range
(
const
Range
&
r
,
const
char
*
delim
=
", "
)
{
return
to_string_range
(
r
.
begin
(),
r
.
end
());
return
to_string_range
(
r
.
begin
(),
r
.
end
()
,
delim
);
}
template
<
class
T
>
inline
std
::
string
to_string_range
(
const
std
::
initializer_list
<
T
>&
r
)
inline
std
::
string
to_string_range
(
const
std
::
initializer_list
<
T
>&
r
,
const
char
*
delim
=
", "
)
{
return
to_string_range
(
r
.
begin
(),
r
.
end
());
return
to_string_range
(
r
.
begin
(),
r
.
end
()
,
delim
);
}
template
<
class
T
>
...
...
src/
targets/gpu/
include/migraphx/
gpu/greater
.hpp
→
src/include/migraphx/
supported_segments
.hpp
View file @
7f97b8ef
...
...
@@ -21,22 +21,24 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_
RTGLIB_GREATER
_HPP
#define MIGRAPHX_GUARD_
RTGLIB_GREATER
_HPP
#ifndef MIGRAPHX_GUARD_
MIGRAPHX_SUPPORTED_SEGMENTS
_HPP
#define MIGRAPHX_GUARD_
MIGRAPHX_SUPPORTED_SEGMENTS
_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/greater.hpp>
#include <unordered_set>
#include <migraphx/instruction_ref.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
struct
hip_greater
:
binary_device
<
hip_greater
,
device
::
greater
>
struct
supported_segment
{
std
::
unordered_set
<
instruction_ref
>
instructions
;
float
metric
;
};
}
// namespace gpu
using
supported_segments
=
std
::
vector
<
supported_segment
>
;
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
src/include/migraphx/target.hpp
View file @
7f97b8ef
...
...
@@ -37,8 +37,10 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/supported_segments.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -64,12 +66,12 @@ struct target
*/
context
get_context
()
const
;
/**
* @brief
Check how well an
instruction
is
supported on a target
with the given metric
* @param
ins Instruction
to check
if it's
supported
* @param metric Used to define how the
return value
should be
interpret
ed
* @return
T
he
value based on the chosen metric. Negative numbers mean unsupported
* @brief
Get the ranges of
instruction
s that are
supported on a target
* @param
module Module
to check
for
supported
instructions
* @param metric Used to define how the
quality of the support
should be
measur
ed
* @return
t
he
supported segments of the graph
*/
float
is_supported
(
T
&
,
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
;
supported_segments
target_
is_supported
(
T
&
,
co
nst
_module
_ref
mod
,
support_metric
m
etric
)
const
;
/**
* @brief copy an argument to the current target.
*
...
...
@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
}
template
<
class
T
>
float
target_
is
_supported
(
T
&
,
i
nst
ruction
_ref
,
support_metric
)
supported_segments
target_
find
_supported
(
T
&
,
co
nst
_module
_ref
,
support_metric
)
{
return
0
;
return
{}
;
}
#ifdef TYPE_ERASED_DECLARATION
...
...
@@ -132,7 +134,7 @@ struct target
//
context
get_context
()
const
;
// (optional)
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
;
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
;
// (optional)
argument
copy_to
(
const
argument
&
input
)
const
;
// (optional)
...
...
@@ -224,10 +226,10 @@ struct target
return
(
*
this
).
private_detail_te_get_handle
().
get_context
();
}
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
is
_supported
(
ins
,
m
);
return
(
*
this
).
private_detail_te_get_handle
().
find
_supported
(
mod
,
m
);
}
argument
copy_to
(
const
argument
&
input
)
const
...
...
@@ -265,29 +267,29 @@ struct target
virtual
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
,
const
compile_options
&
options
)
const
=
0
;
virtual
context
get_context
()
const
=
0
;
virtual
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
=
0
;
virtual
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
=
0
;
virtual
argument
copy_to
(
const
argument
&
input
)
const
=
0
;
virtual
argument
copy_from
(
const
argument
&
input
)
const
=
0
;
virtual
argument
allocate
(
const
shape
&
s
)
const
=
0
;
};
template
<
class
T
>
static
auto
private_detail_te_default_
is
_supported
(
char
,
static
auto
private_detail_te_default_
find
_supported
(
char
,
T
&&
private_detail_te_self
,
instruction
_ref
ins
,
const_module
_ref
mod
,
support_metric
m
)
->
decltype
(
private_detail_te_self
.
is
_supported
(
ins
,
m
))
->
decltype
(
private_detail_te_self
.
find
_supported
(
mod
,
m
))
{
return
private_detail_te_self
.
is
_supported
(
ins
,
m
);
return
private_detail_te_self
.
find
_supported
(
mod
,
m
);
}
template
<
class
T
>
static
float
private_detail_te_default_
is
_supported
(
float
,
static
supported_segments
private_detail_te_default_
find
_supported
(
float
,
T
&&
private_detail_te_self
,
instruction
_ref
ins
,
const_module
_ref
mod
,
support_metric
m
)
{
return
target_
is
_supported
(
private_detail_te_self
,
ins
,
m
);
return
target_
find
_supported
(
private_detail_te_self
,
mod
,
m
);
}
template
<
class
T
>
...
...
@@ -349,7 +351,7 @@ struct target
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -372,10 +374,11 @@ struct target
context
get_context
()
const
override
{
return
private_detail_te_value
.
get_context
();
}
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
override
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
override
{
return
private_detail_te_default_is_supported
(
char
(
0
),
private_detail_te_value
,
ins
,
m
);
return
private_detail_te_default_find_supported
(
char
(
0
),
private_detail_te_value
,
mod
,
m
);
}
argument
copy_to
(
const
argument
&
input
)
const
override
...
...
@@ -423,7 +426,7 @@ struct target
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/target_assignments.hpp
View file @
7f97b8ef
...
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#include <unordered_map>
#include <string>
#include <migraphx/instruction_ref.hpp>
...
...
@@ -33,10 +34,20 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
target_assignments
{
void
add_assignment
(
instruction_ref
ins
,
const
std
::
string
&
target
);
using
iterator
=
std
::
unordered_map
<
instruction_ref
,
std
::
string
>::
const_iterator
;
using
value_type
=
std
::
pair
<
instruction_ref
,
std
::
string
>
;
auto
begin
()
const
{
return
assignments
.
cbegin
();
}
auto
end
()
const
{
return
assignments
.
cend
();
}
auto
size
()
const
{
return
assignments
.
size
();
}
auto
&
at
(
instruction_ref
ins
)
const
{
return
assignments
.
at
(
ins
);
}
auto
insert
(
iterator
it
,
const
std
::
pair
<
instruction_ref
,
std
::
string
>&
assignment
)
{
return
assignments
.
insert
(
it
,
assignment
);
}
auto
find
(
instruction_ref
ins
)
const
{
return
assignments
.
find
(
ins
);
}
auto
begin
()
const
{
return
assignments
.
begin
();
}
auto
end
()
const
{
return
assignments
.
end
();
}
private:
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
assignments
;
...
...
src/include/migraphx/tensor_view.hpp
View file @
7f97b8ef
...
...
@@ -67,7 +67,7 @@ struct tensor_view
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
bool
empty
()
const
{
return
m_data
==
nullptr
||
m_shape
.
lens
().
empty
();
}
bool
empty
()
const
{
return
m_data
==
nullptr
or
m_shape
.
lens
().
empty
();
}
std
::
size_t
size
()
const
{
return
m_shape
.
elements
();
}
...
...
@@ -109,37 +109,37 @@ struct tensor_view
T
&
operator
[](
std
::
size_t
i
)
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
assert
(
not
this
->
empty
()
&&
i
<
this
->
size
());
return
m_data
[
m_shape
.
index
(
i
)];
}
const
T
&
operator
[](
std
::
size_t
i
)
const
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
assert
(
not
this
->
empty
()
&&
i
<
this
->
size
());
return
m_data
[
m_shape
.
index
(
i
)];
}
T
&
front
()
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
0
];
}
const
T
&
front
()
const
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
0
];
}
T
&
back
()
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
}
const
T
&
back
()
const
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
}
...
...
@@ -159,7 +159,7 @@ struct tensor_view
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
{
if
(
!
x
.
empty
())
if
(
not
x
.
empty
())
{
os
<<
as_number
(
x
.
front
());
for
(
std
::
size_t
i
=
1
;
i
<
x
.
m_shape
.
elements
();
i
++
)
...
...
@@ -182,7 +182,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
get_shape
().
elements
();
i
++
)
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
if
(
not
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
}
return
true
;
...
...
@@ -193,7 +193,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
template
<
class
T
,
class
U
>
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
U
>&
y
)
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
template
<
class
T
>
...
...
src/include/migraphx/tune_axis.hpp
View file @
7f97b8ef
...
...
@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS {
inline
int
tune_axis
(
const
int
n_dim
,
const
int
axis
,
const
std
::
string
&
op_name
=
"OPERATOR"
)
{
if
(
axis
>=
n_dim
||
std
::
abs
(
axis
)
>
n_dim
)
if
(
axis
>=
n_dim
or
std
::
abs
(
axis
)
>
n_dim
)
{
MIGRAPHX_THROW
(
to_upper
(
op_name
)
+
": axis is out of range."
);
}
...
...
Prev
1
2
3
4
5
6
7
8
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment