Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
40fbef9b
Unverified
Commit
40fbef9b
authored
Aug 05, 2023
by
Ted Themistokleous
Committed by
GitHub
Aug 05, 2023
Browse files
Merge branch 'develop' into threaded_nms
parents
d164b151
aeb9f78c
Changes
440
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
197 additions
and
62 deletions
+197
-62
src/include/migraphx/rewrite_pooling.hpp
src/include/migraphx/rewrite_pooling.hpp
+1
-1
src/include/migraphx/rewrite_quantization.hpp
src/include/migraphx/rewrite_quantization.hpp
+1
-1
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+1
-1
src/include/migraphx/schedule.hpp
src/include/migraphx/schedule.hpp
+1
-1
src/include/migraphx/schedule_model.hpp
src/include/migraphx/schedule_model.hpp
+3
-3
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+59
-23
src/include/migraphx/simplify_algebra.hpp
src/include/migraphx/simplify_algebra.hpp
+1
-1
src/include/migraphx/simplify_qdq.hpp
src/include/migraphx/simplify_qdq.hpp
+1
-1
src/include/migraphx/simplify_reshapes.hpp
src/include/migraphx/simplify_reshapes.hpp
+1
-1
src/include/migraphx/source_location.hpp
src/include/migraphx/source_location.hpp
+73
-0
src/include/migraphx/split_single_dyn_dim.hpp
src/include/migraphx/split_single_dyn_dim.hpp
+1
-1
src/include/migraphx/sqlite.hpp
src/include/migraphx/sqlite.hpp
+1
-1
src/include/migraphx/stream_model.hpp
src/include/migraphx/stream_model.hpp
+3
-3
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+8
-3
src/include/migraphx/tf.hpp
src/include/migraphx/tf.hpp
+5
-1
src/include/migraphx/tmp_dir.hpp
src/include/migraphx/tmp_dir.hpp
+1
-1
src/include/migraphx/value.hpp
src/include/migraphx/value.hpp
+24
-10
src/include/migraphx/verify.hpp
src/include/migraphx/verify.hpp
+2
-0
src/include/migraphx/verify_args.hpp
src/include/migraphx/verify_args.hpp
+1
-0
src/instruction.cpp
src/instruction.cpp
+9
-9
No files found.
src/include/migraphx/rewrite_pooling.hpp
View file @
40fbef9b
...
...
@@ -35,7 +35,7 @@ struct module;
/**
* Rewrite pooling to reduce_mean
*/
struct
rewrite_pooling
struct
MIGRAPHX_EXPORT
rewrite_pooling
{
std
::
string
name
()
const
{
return
"rewrite_pooling"
;
}
void
apply
(
module
&
m
)
const
;
...
...
src/include/migraphx/rewrite_quantization.hpp
View file @
40fbef9b
...
...
@@ -35,7 +35,7 @@ struct module;
/**
* Rewrite quantization ops to equivalent operators
*/
struct
rewrite_quantization
struct
MIGRAPHX_EXPORT
rewrite_quantization
{
std
::
string
name
()
const
{
return
"rewrite_quantization"
;
}
void
apply
(
module
&
m
)
const
;
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
40fbef9b
...
...
@@ -39,7 +39,7 @@ struct module;
/**
* Rewrite rnn to gemm and add.
*/
struct
rewrite_rnn
struct
MIGRAPHX_EXPORT
rewrite_rnn
{
std
::
string
name
()
const
{
return
"rewrite_rnn"
;
}
void
apply
(
module
&
m
)
const
;
...
...
src/include/migraphx/schedule.hpp
View file @
40fbef9b
...
...
@@ -37,7 +37,7 @@ struct module;
/**
* Schedule instructions for concurrent execution
*/
struct
schedule
struct
MIGRAPHX_EXPORT
schedule
{
schedule_model
model
{};
bool
enable
=
true
;
...
...
src/include/migraphx/schedule_model.hpp
View file @
40fbef9b
...
...
@@ -63,7 +63,7 @@ struct schedule_model
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct
schedule_model
struct
MIGRAPHX_EXPORT
schedule_model
{
//
std
::
size_t
concurrency
()
const
;
...
...
@@ -99,7 +99,7 @@ struct schedule_model
{
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
derived
and
private_detail_te_handle_mem_var
.
u
se_count
()
==
1
)
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
...
...
@@ -274,7 +274,7 @@ struct schedule_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
not
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
private_detail_te_handle_mem_var
.
u
se_count
()
>
1
)
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/shape.hpp
View file @
40fbef9b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -43,7 +43,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
value
;
struct
shape_impl
;
struct
shape
struct
MIGRAPHX_EXPORT
shape
{
// Add new types here
...
...
@@ -85,7 +85,7 @@ struct shape
{
};
struct
dynamic_dimension
struct
MIGRAPHX_EXPORT
dynamic_dimension
{
std
::
size_t
min
=
0
;
std
::
size_t
max
=
0
;
...
...
@@ -100,22 +100,28 @@ struct shape
bool
is_fixed
()
const
;
bool
has_optimal
()
const
;
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
dynamic_dimension
&
x
);
MIGRAPHX_EXPORT
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
MIGRAPHX_EXPORT
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
dynamic_dimension
&
x
);
// compare to fixed std::size_t dimension
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
bool
operator
==
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
bool
operator
!=
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
==
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
!=
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
// add and subtract fixed std::size_t dimension
dynamic_dimension
&
operator
+=
(
const
std
::
size_t
&
x
);
dynamic_dimension
&
operator
-=
(
const
std
::
size_t
&
x
);
friend
dynamic_dimension
operator
+
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
dynamic_dimension
operator
+
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
friend
dynamic_dimension
operator
-
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
MIGRAPHX_EXPORT
friend
dynamic_dimension
operator
+
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
MIGRAPHX_EXPORT
friend
dynamic_dimension
operator
+
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
MIGRAPHX_EXPORT
friend
dynamic_dimension
operator
-
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
};
static
const
std
::
vector
<
type_t
>&
types
();
...
...
@@ -156,14 +162,34 @@ struct shape
shape
(
const
std
::
vector
<
shape
>&
subs
);
/**
* Creates an output shape with dimensions equal to the input lengths and strides determined
* by the permutation argument such that find_permutation() of the output shape returns the
* inputted permuation.
*
* 2D example:
* parameters:
* l = [2, 3], perm = [1, 0]
* therefore:
* "original" shape = {lens = [3, 2], strides = [2, 1]}
* output_shape = {lens = [2, 3], strides = [1, 2]
*
* 3D example:
* parameters:
* l = [2, 3, 4], perm = [1, 2, 0]
* therefore:
* "original" shape = {lens = [3, 4, 2], strides = [8, 2, 1]}
* output_shape = {lens = [2, 3, 4], strides = [1, 8, 2]}
*/
static
shape
from_permutation
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
,
const
std
::
vector
<
int64_t
>&
perm
);
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
/*!
* The number of dimensions in the shape.
* The number of dimensions in the shape
, either static or dynamic
.
* Same as the number of indices required to get a data value.
*/
std
::
size_t
ndim
()
const
;
...
...
@@ -214,6 +240,10 @@ struct shape
template
<
class
Iterator
>
std
::
size_t
index
(
Iterator
start
,
Iterator
last
)
const
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: index() called on dynamic shape"
);
}
assert
(
std
::
distance
(
start
,
last
)
<=
this
->
lens
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
start
,
last
,
this
->
strides
().
begin
(),
std
::
size_t
{
0
});
// NOLINT
...
...
@@ -222,11 +252,15 @@ struct shape
/// Map element index to space index
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
i
)
const
;
void
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
/// Map element index to multi-dimensional index
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
idx
)
const
;
/// Map element index to multi-dimensional index and put them them into location provided by
/// pointers
void
multi_copy
(
std
::
size_t
idx
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
/// Returns true if the shape is packed (number of elements and buffer size the same) with
no
/// padding
/// Returns true if the shape is packed (number of elements and buffer size the same) with
///
no
padding
bool
packed
()
const
;
/// Returns true is the shape has been transposed. That is the strides are not in descending
...
...
@@ -262,9 +296,9 @@ struct shape
// convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape
to_static
(
std
::
size_t
x
)
const
;
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
MIGRAPHX_EXPORT
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
MIGRAPHX_EXPORT
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
template
<
class
T
>
struct
as
...
...
@@ -275,6 +309,8 @@ struct shape
type
min
()
const
{
return
std
::
numeric_limits
<
type
>::
lowest
();
}
type
nan
()
const
{
return
std
::
numeric_limits
<
type
>::
quiet_NaN
();
}
template
<
class
U
>
type
operator
()(
U
u
)
const
{
...
...
@@ -370,8 +406,8 @@ struct shape
std
::
shared_ptr
<
const
shape_impl
>
impl
;
};
void
migraphx_to_value
(
value
&
v
,
const
shape
&
s
);
void
migraphx_from_value
(
const
value
&
v
,
shape
&
s
);
MIGRAPHX_EXPORT
void
migraphx_to_value
(
value
&
v
,
const
shape
&
s
);
MIGRAPHX_EXPORT
void
migraphx_from_value
(
const
value
&
v
,
shape
&
s
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/simplify_algebra.hpp
View file @
40fbef9b
...
...
@@ -35,7 +35,7 @@ struct module;
/**
* Simplify many algebraic instructions to more efficient versions.
*/
struct
simplify_algebra
struct
MIGRAPHX_EXPORT
simplify_algebra
{
std
::
string
name
()
const
{
return
"simplify_algebra"
;
}
void
apply
(
module
&
m
)
const
;
...
...
src/include/migraphx/simplify_qdq.hpp
View file @
40fbef9b
...
...
@@ -36,7 +36,7 @@ struct module;
* Inserts quantized operators in place of dq->quantizable_op->q
* then removes remaining fake quantization (q->dq pairs)
*/
struct
simplify_qdq
struct
MIGRAPHX_EXPORT
simplify_qdq
{
std
::
string
name
()
const
{
return
"simplify_qdq"
;
}
void
apply
(
module
&
m
)
const
;
...
...
src/include/migraphx/simplify_reshapes.hpp
View file @
40fbef9b
...
...
@@ -36,7 +36,7 @@ struct module;
/**
* Eliminate redundant reshapes.
*/
struct
simplify_reshapes
struct
MIGRAPHX_EXPORT
simplify_reshapes
{
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
void
apply
(
module
&
m
)
const
;
...
...
src/include/migraphx/source_location.hpp
0 → 100644
View file @
40fbef9b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#include <migraphx/config.hpp>
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#elif defined(__has_include)
#if __has_include(<source_location>) && __cplusplus >= 202003L
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#endif
#if __has_include(<experimental/source_location>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif
#if MIGRAPHX_HAS_SOURCE_LOCATION
#include <source_location>
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
#include <experimental/source_location>
#endif
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
#if MIGRAPHX_HAS_SOURCE_LOCATION
using
source_location
=
std
::
source_location
;
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
using
source_location
=
std
::
experimental
::
source_location
;
#else
struct
source_location
{
static
constexpr
source_location
current
()
noexcept
{
return
source_location
{};
}
constexpr
std
::
uint_least32_t
line
()
const
noexcept
{
return
0
;
}
constexpr
std
::
uint_least32_t
column
()
const
noexcept
{
return
0
;
}
constexpr
const
char
*
file_name
()
const
noexcept
{
return
""
;
}
constexpr
const
char
*
function_name
()
const
noexcept
{
return
""
;
}
};
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
src/include/migraphx/split_single_dyn_dim.hpp
View file @
40fbef9b
...
...
@@ -36,7 +36,7 @@ inline namespace MIGRAPHX_INLINE_NS {
* Split dynamic dimension over submodules if exactly one dimension in the parameter list is
* dynamic.
*/
struct
split_single_dyn_dim
struct
MIGRAPHX_EXPORT
split_single_dyn_dim
{
std
::
string
name
()
const
{
return
"split_single_dyn_dim"
;
}
void
apply
(
module_pass_manager
&
)
const
;
...
...
src/include/migraphx/sqlite.hpp
View file @
40fbef9b
...
...
@@ -35,7 +35,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
sqlite_impl
;
struct
sqlite
struct
MIGRAPHX_EXPORT
sqlite
{
sqlite
()
=
default
;
static
sqlite
read
(
const
fs
::
path
&
p
);
...
...
src/include/migraphx/stream_model.hpp
View file @
40fbef9b
...
...
@@ -62,7 +62,7 @@ struct stream_model
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct
stream_model
struct
MIGRAPHX_EXPORT
stream_model
{
//
std
::
size_t
get_nstream
()
const
;
...
...
@@ -100,7 +100,7 @@ struct stream_model
{
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
derived
and
private_detail_te_handle_mem_var
.
u
se_count
()
==
1
)
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
...
...
@@ -288,7 +288,7 @@ struct stream_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
not
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
private_detail_te_handle_mem_var
.
u
se_count
()
>
1
)
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/target.hpp
View file @
40fbef9b
...
...
@@ -45,6 +45,8 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
value
;
#ifdef DOXYGEN
/// An interface for a compilation target
...
...
@@ -125,7 +127,7 @@ supported_segments target_find_supported(T&, const_module_ref, support_metric)
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct
target
struct
MIGRAPHX_EXPORT
target
{
//
std
::
string
name
()
const
;
...
...
@@ -165,7 +167,7 @@ struct target
{
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
derived
and
private_detail_te_handle_mem_var
.
u
se_count
()
==
1
)
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
...
...
@@ -426,7 +428,7 @@ struct target
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
not
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
private_detail_te_handle_mem_var
.
u
se_count
()
>
1
)
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x)
#endif
void
migraphx_to_value
(
value
&
v
,
const
target
&
t
);
void
migraphx_from_value
(
const
value
&
v
,
target
&
t
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/tf.hpp
View file @
40fbef9b
...
...
@@ -26,6 +26,7 @@
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf/export.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -41,7 +42,10 @@ struct tf_options
};
/// Create a program from a tf pb file (default is nhwc format)
program
parse_tf
(
const
std
::
string
&
name
,
const
tf_options
&
options
=
tf_options
{});
MIGRAPHX_TF_EXPORT
program
parse_tf
(
const
std
::
string
&
name
,
const
tf_options
&
options
=
tf_options
{});
MIGRAPHX_TF_EXPORT
std
::
vector
<
std
::
string
>
get_tf_operators
();
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/tmp_dir.hpp
View file @
40fbef9b
...
...
@@ -30,7 +30,7 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
tmp_dir
struct
MIGRAPHX_EXPORT
tmp_dir
{
fs
::
path
path
;
tmp_dir
(
const
std
::
string
&
prefix
=
""
);
...
...
src/include/migraphx/value.hpp
View file @
40fbef9b
...
...
@@ -32,6 +32,7 @@
#include <algorithm>
#include <cassert>
#include <memory>
#include <cstdint>
#include <sstream>
#include <type_traits>
#include <tuple>
...
...
@@ -140,7 +141,7 @@ To try_convert_value(const From& x)
return
detail
::
try_convert_value_impl
<
To
>
(
rank
<
3
>
{},
x
);
}
struct
value
struct
MIGRAPHX_EXPORT
value
{
// clang-format off
#define MIGRAPHX_VISIT_VALUE_TYPES(m) \
...
...
@@ -392,8 +393,8 @@ struct value
return; \
}
MIGRAPHX_VISIT_VALUE_TYPES
(
MIGRAPHX_VALUE_GENERATE_CASE_VALUE
)
MIGRAPHX_VALUE_GENERATE_CASE
(
array
,
)
MIGRAPHX_VALUE_GENERATE_CASE
(
object
,
)
MIGRAPHX_VALUE_GENERATE_CASE
_VALUE
(
array
,
)
MIGRAPHX_VALUE_GENERATE_CASE
_VALUE
(
object
,
)
}
MIGRAPHX_THROW
(
"Unknown type"
);
}
...
...
@@ -452,14 +453,16 @@ struct value
std
::
vector
<
literal_to_string
<
To
>>
{
default_value
.
begin
(),
default_value
.
end
()});
}
friend
bool
operator
==
(
const
value
&
x
,
const
value
&
y
);
friend
bool
operator
!=
(
const
value
&
x
,
const
value
&
y
);
friend
bool
operator
<
(
const
value
&
x
,
const
value
&
y
);
friend
bool
operator
<=
(
const
value
&
x
,
const
value
&
y
);
friend
bool
operator
>
(
const
value
&
x
,
const
value
&
y
);
friend
bool
operator
>=
(
const
value
&
x
,
const
value
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
==
(
const
value
&
x
,
const
value
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
!=
(
const
value
&
x
,
const
value
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
<
(
const
value
&
x
,
const
value
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
<=
(
const
value
&
x
,
const
value
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
>
(
const
value
&
x
,
const
value
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
>=
(
const
value
&
x
,
const
value
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
value
&
d
);
MIGRAPHX_EXPORT
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
value
&
d
);
std
::
size_t
hash
()
const
;
void
debug_print
(
bool
show_type
=
false
)
const
;
...
...
@@ -481,4 +484,15 @@ struct value
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
namespace
std
{
template
<
>
struct
hash
<
migraphx
::
value
>
{
using
argument_type
=
migraphx
::
value
;
using
result_type
=
std
::
size_t
;
result_type
operator
()(
const
migraphx
::
value
&
x
)
const
{
return
x
.
hash
();
}
};
}
// namespace std
#endif
src/include/migraphx/verify.hpp
View file @
40fbef9b
...
...
@@ -35,6 +35,7 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
verify
{
// Compute the value of a range
template
<
class
R
>
...
...
@@ -196,6 +197,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
return
error
<=
threshold
;
}
}
// namespace verify
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/verify_args.hpp
View file @
40fbef9b
...
...
@@ -31,6 +31,7 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_EXPORT
bool
verify_args
(
const
std
::
string
&
name
,
const
argument
&
ref_arg
,
const
argument
&
target_arg
,
...
...
src/instruction.cpp
100755 → 100644
View file @
40fbef9b
...
...
@@ -64,10 +64,7 @@ void instruction::replace(const shape& r)
result
=
r
;
for
(
auto
&&
ins
:
output
)
{
if
(
ins
->
name
()
==
"@return"
)
continue
;
assert
(
ins
->
name
().
front
()
!=
'@'
);
assert
(
ins
->
name
()
==
"@return"
or
ins
->
name
().
front
()
!=
'@'
);
ins
->
recompute_shape
();
}
}
...
...
@@ -122,10 +119,6 @@ bool instruction::valid() const
{
computed
=
result
;
}
else
if
(
op
.
name
()
==
"@return"
)
{
computed
=
{};
}
else
{
try
...
...
@@ -145,6 +138,7 @@ bool instruction::valid() const
}
shape
instruction
::
get_shape
()
const
{
return
result
;
}
const
literal
&
instruction
::
get_literal
()
const
{
assert
(
op
.
name
()
==
"@literal"
);
...
...
@@ -406,6 +400,9 @@ void instruction::print(std::ostream& os,
// skip return instruction shape
if
(
ins
->
name
()
!=
"@return"
)
os
<<
" -> "
<<
ins
->
get_shape
();
// print tid
os
<<
", target_id="
<<
ins
->
target_id
;
}
static
void
debug_name
(
std
::
ostream
&
os
,
const
instruction
&
ins
)
...
...
@@ -464,11 +461,14 @@ operation instruction::normalized_operator() const
if
(
this
->
need_normalization
())
{
auto
s
=
this
->
inputs
().
front
()
->
get_shape
();
if
(
not
normalize_attributes
(
o
,
s
.
max_lens
()
))
if
(
not
normalize_attributes
(
o
,
s
))
return
this
->
get_operator
();
}
return
o
;
}
std
::
size_t
instruction
::
get_target_id
()
const
{
return
target_id
;
}
void
instruction
::
set_target_id
(
std
::
size_t
tid
)
{
this
->
target_id
=
tid
;
}
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
)
{
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment