Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
baac1dab
Commit
baac1dab
authored
May 24, 2023
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/develop' into ck-host-lib
parents
830dff7a
77042e30
Changes
299
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
197 additions
and
54 deletions
+197
-54
src/include/migraphx/auto_any_cast.hpp
src/include/migraphx/auto_any_cast.hpp
+1
-1
src/include/migraphx/check_context.hpp
src/include/migraphx/check_context.hpp
+23
-1
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+20
-3
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+5
-0
src/include/migraphx/compile_options.hpp
src/include/migraphx/compile_options.hpp
+7
-1
src/include/migraphx/config.hpp
src/include/migraphx/config.hpp
+19
-9
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+3
-1
src/include/migraphx/cpp_generator.hpp
src/include/migraphx/cpp_generator.hpp
+6
-0
src/include/migraphx/dynamic_loader.hpp
src/include/migraphx/dynamic_loader.hpp
+6
-0
src/include/migraphx/fuse_reduce.hpp
src/include/migraphx/fuse_reduce.hpp
+11
-6
src/include/migraphx/half.hpp
src/include/migraphx/half.hpp
+1
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+21
-1
src/include/migraphx/memory_coloring.hpp
src/include/migraphx/memory_coloring.hpp
+2
-1
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+2
-0
src/include/migraphx/msgpack.hpp
src/include/migraphx/msgpack.hpp
+2
-0
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+1
-1
src/include/migraphx/op/allocate.hpp
src/include/migraphx/op/allocate.hpp
+1
-1
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+1
-1
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+64
-25
src/include/migraphx/op/contiguous.hpp
src/include/migraphx/op/contiguous.hpp
+1
-1
No files found.
src/include/migraphx/auto_any_cast.hpp
View file @
baac1dab
...
@@ -42,7 +42,7 @@ void any_cast()
...
@@ -42,7 +42,7 @@ void any_cast()
template
<
class
T
>
template
<
class
T
>
struct
auto_any_caster
struct
auto_any_caster
{
{
T
&
x
;
T
&
x
;
// NOLINT
template
<
class
U
>
template
<
class
U
>
operator
U
&
()
operator
U
&
()
...
...
src/include/migraphx/check_context.hpp
View file @
baac1dab
...
@@ -27,6 +27,8 @@
...
@@ -27,6 +27,8 @@
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -36,7 +38,27 @@ struct check_context
...
@@ -36,7 +38,27 @@ struct check_context
{
{
struct
op
:
auto_register_op
<
op
>
struct
op
:
auto_register_op
<
op
>
{
{
std
::
string
name
()
const
{
return
"check_context::"
+
get_type_name
<
T
>
();
}
static
std
::
string
compute_op_name
()
{
const
auto
&
op_type_name
=
get_type_name
<
T
>
();
const
auto
&
split_name
=
split_string
(
op_type_name
,
':'
);
std
::
vector
<
std
::
string
>
name_without_version
=
{
"check_context"
};
// op_type_name would contain internal namespace name with version_x_y_z
// remove version and construct op_name such as check_context::migraphx::gpu::context
std
::
copy_if
(
split_name
.
begin
(),
split_name
.
end
(),
std
::
back_inserter
(
name_without_version
),
[
&
](
const
auto
&
i
)
{
return
not
i
.
empty
()
and
not
contains
(
i
,
"version"
);
});
return
join_strings
(
name_without_version
,
"::"
);
}
std
::
string
name
()
const
{
static
auto
op_name
=
compute_op_name
();
return
op_name
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
)
const
{
return
{};
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
)
const
{
return
{};
}
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
{
...
...
src/include/migraphx/check_shapes.hpp
View file @
baac1dab
...
@@ -38,8 +38,8 @@ struct check_shapes
...
@@ -38,8 +38,8 @@ struct check_shapes
{
{
const
shape
*
begin
;
const
shape
*
begin
;
const
shape
*
end
;
const
shape
*
end
;
const
std
::
string
name
;
std
::
string
name
;
const
bool
dynamic_allowed
;
bool
dynamic_allowed
;
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
...
@@ -87,7 +87,7 @@ struct check_shapes
...
@@ -87,7 +87,7 @@ struct check_shapes
}
}
/*!
/*!
*
Check if
the number of shape objects
is
equal to
atleast
one of the
*
Require
the number of shape objects
to
equal to one of the
* given sizes.
* given sizes.
* \param ns template parameter pack of sizes to check against
* \param ns template parameter pack of sizes to check against
*/
*/
...
@@ -100,6 +100,23 @@ struct check_shapes
...
@@ -100,6 +100,23 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Require the number of shape objects to equal at least a given amount. Use this
* method for ops that can take any number (variadic) of inputs.
* \param n min. number of shapes
*/
const
check_shapes
&
has_at_least
(
std
::
size_t
n
)
const
{
if
(
this
->
size
()
<
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of arguments: expected at least "
+
to_string
(
n
)
+
" but given "
+
std
::
to_string
(
size
()));
return
*
this
;
}
/*!
* Require all shapes to have the same number of elements.
* \param n number of
*/
const
check_shapes
&
nelements
(
std
::
size_t
n
)
const
const
check_shapes
&
nelements
(
std
::
size_t
n
)
const
{
{
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
...
...
src/include/migraphx/common.hpp
View file @
baac1dab
...
@@ -41,6 +41,11 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
...
@@ -41,6 +41,11 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
std
::
vector
<
instruction_ref
>
insert_common_args
(
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
);
std
::
vector
<
instruction_ref
>
add_common_args
(
module
&
m
,
std
::
vector
<
instruction_ref
>
inputs
);
instruction_ref
insert_common_op
(
module
&
m
,
instruction_ref
insert_common_op
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
const
operation
&
op
,
const
operation
&
op
,
...
...
src/include/migraphx/compile_options.hpp
View file @
baac1dab
...
@@ -32,8 +32,14 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -32,8 +32,14 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
compile_options
struct
compile_options
{
{
/**
* Have MIGX allocate memory for parameters and add instructions
* to copy parameters and output to/from an offload device like a GPU.
*/
bool
offload_copy
=
false
;
bool
offload_copy
=
false
;
bool
fast_math
=
true
;
bool
fast_math
=
true
;
bool
exhaustive_tune
=
false
;
tracer
trace
{};
tracer
trace
{};
};
};
...
...
src/include/migraphx/config.hpp
View file @
baac1dab
...
@@ -24,22 +24,32 @@
...
@@ -24,22 +24,32 @@
#ifndef MIGRAPHX_GUARD_CONFIG_HPP
#ifndef MIGRAPHX_GUARD_CONFIG_HPP
#define MIGRAPHX_GUARD_CONFIG_HPP
#define MIGRAPHX_GUARD_CONFIG_HPP
namespace
migraphx
{
#if !defined(MIGRAPHX_USE_CLANG_TIDY) && !defined(DOXYGEN)
#if !defined(MIGRAPHX_USE_CLANG_TIDY) && !defined(DOXYGEN)
#ifdef BUILD_DEV
#define MIGRAPHX_INLINE_NS version_1
#define MIGRAPHX_INLINE_NS version_1
#endif
#else
#include <migraphx/version.h>
#define MIGRAPHX_VERSION_PRIMITIVE_CONCAT(x, y) x##_##y
#define MIGRAPHX_VERSION_CONCAT(x, y) MIGRAPHX_VERSION_PRIMITIVE_CONCAT(x, y)
#define MIGRAPHX_VERSION \
MIGRAPHX_VERSION_CONCAT( \
MIGRAPHX_VERSION_CONCAT(MIGRAPHX_VERSION_MAJOR, MIGRAPHX_VERSION_MINOR), \
MIGRAPHX_VERSION_PATCH)
#define MIGRAPHX_INLINE_NS MIGRAPHX_VERSION_CONCAT(version, MIGRAPHX_VERSION)
#endif // build_dev
#endif // clang_tidy
#ifdef DOXYGEN
#ifdef DOXYGEN
#define MIGRAPHX_INLINE_NS internal
#define MIGRAPHX_INLINE_NS internal
#endif
#endif
// doxygen
#ifdef MIGRAPHX_USE_CLANG_TIDY
#ifdef MIGRAPHX_USE_CLANG_TIDY
#define MIGRAPHX_TIDY_CONST const
#define MIGRAPHX_TIDY_CONST const
#else
#else
#define MIGRAPHX_TIDY_CONST
#define MIGRAPHX_TIDY_CONST
#endif
#endif // tidy_const
#endif // clang_tidy
}
// namespace migraphx
#endif
src/include/migraphx/context.hpp
View file @
baac1dab
...
@@ -66,6 +66,7 @@ any_ptr get_queue_context(T&)
...
@@ -66,6 +66,7 @@ any_ptr get_queue_context(T&)
{
{
return
{};
return
{};
}
}
template
<
class
T
>
template
<
class
T
>
void
wait_for_context
(
T
&
,
any_ptr
)
void
wait_for_context
(
T
&
,
any_ptr
)
{
{
...
@@ -302,7 +303,7 @@ struct context
...
@@ -302,7 +303,7 @@ struct context
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
value
)
:
private_detail_te_value
(
std
::
move
(
value
)
)
{
{
}
}
...
@@ -412,6 +413,7 @@ inline const ValueType& any_cast(const context& x)
...
@@ -412,6 +413,7 @@ inline const ValueType& any_cast(const context& x)
#endif
#endif
inline
void
migraphx_to_value
(
value
&
v
,
const
context
&
ctx
)
{
v
=
ctx
.
to_value
();
}
inline
void
migraphx_to_value
(
value
&
v
,
const
context
&
ctx
)
{
v
=
ctx
.
to_value
();
}
inline
void
migraphx_from_value
(
const
value
&
v
,
context
&
ctx
)
{
ctx
.
from_value
(
v
);
}
inline
void
migraphx_from_value
(
const
value
&
v
,
context
&
ctx
)
{
ctx
.
from_value
(
v
);
}
#endif
#endif
...
...
src/include/migraphx/cpp_generator.hpp
View file @
baac1dab
...
@@ -77,6 +77,8 @@ struct cpp_generator
...
@@ -77,6 +77,8 @@ struct cpp_generator
function
&
set_types
(
const
module
&
m
);
function
&
set_types
(
const
module
&
m
);
function
&
set_types
(
const
module
&
m
,
const
std
::
function
<
std
::
string
(
shape
)
>&
parse
);
function
&
set_types
(
const
module
&
m
,
const
std
::
function
<
std
::
string
(
shape
)
>&
parse
);
function
&
set_generic_types
(
const
module
&
m
);
function
&
set_generic_types
(
const
module
&
m
);
function
&
add_generic_param
(
const
std
::
string
&
pname
);
function
&
unused_param
(
const
std
::
string
&
pname
);
};
};
cpp_generator
();
cpp_generator
();
...
@@ -105,6 +107,10 @@ struct cpp_generator
...
@@ -105,6 +107,10 @@ struct cpp_generator
std
::
string
create_function
(
const
function
&
f
);
std
::
string
create_function
(
const
function
&
f
);
static
std
::
vector
<
std
::
string
>
to_args
(
const
std
::
vector
<
instruction_ref
>&
inputs
,
const
std
::
unordered_map
<
instruction_ref
,
std
::
string
>&
names
);
private:
private:
std
::
unique_ptr
<
cpp_generator_impl
>
impl
;
std
::
unique_ptr
<
cpp_generator_impl
>
impl
;
};
};
...
...
src/include/migraphx/dynamic_loader.hpp
View file @
baac1dab
...
@@ -37,6 +37,12 @@ struct dynamic_loader_impl;
...
@@ -37,6 +37,12 @@ struct dynamic_loader_impl;
struct
dynamic_loader
struct
dynamic_loader
{
{
template
<
class
T
>
static
fs
::
path
path
(
T
*
address
)
{
return
path
(
reinterpret_cast
<
void
*>
(
address
));
}
static
fs
::
path
path
(
void
*
address
);
dynamic_loader
()
=
default
;
dynamic_loader
()
=
default
;
dynamic_loader
(
const
fs
::
path
&
p
);
dynamic_loader
(
const
fs
::
path
&
p
);
...
...
src/include/migraphx/
pass_config
.hpp
→
src/include/migraphx/
fuse_reduce
.hpp
View file @
baac1dab
...
@@ -21,18 +21,23 @@
...
@@ -21,18 +21,23 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_REDUCE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_REDUCE_HPP
#ifndef MIGRAPHX_GUARD_PASS_CONFIG_HPP
#define MIGRAPHX_GUARD_PASS_CONFIG_HPP
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <string>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_MEMORY_COLORING
)
struct
module_pass_manager
;
struct
fuse_reduce
{
std
::
string
name
()
const
{
return
"fuse_reduce"
;
}
void
apply
(
module_pass_manager
&
mpm
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_
PASS_CONFIG
_HPP
#endif // MIGRAPHX_GUARD_
MIGRAPHX_FUSE_POINTWISE
_HPP
src/include/migraphx/half.hpp
View file @
baac1dab
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_HALF_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_HALF_HPP
#define MIGRAPHX_GUARD_RTGLIB_HALF_HPP
#define MIGRAPHX_GUARD_RTGLIB_HALF_HPP
#include <half.hpp>
#include <half
/half
.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/include/migraphx/matcher.hpp
View file @
baac1dab
...
@@ -347,6 +347,7 @@ match::matcher_result find_match(module& modl, M&& m)
...
@@ -347,6 +347,7 @@ match::matcher_result find_match(module& modl, M&& m)
}
}
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_VALIDATE_MATCHES
)
/// Find matches for an instruction in the module
/// Find matches for an instruction in the module
template
<
class
Mod
,
class
...
Ms
>
template
<
class
Mod
,
class
...
Ms
>
...
@@ -356,7 +357,11 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
...
@@ -356,7 +357,11 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
const
const
#endif
#endif
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
bool
match
=
false
;
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool
validate
=
enabled
(
MIGRAPHX_VALIDATE_MATCHES
{});
bool
match
=
false
;
each_args
(
each_args
(
[
&
](
auto
&&
m
)
{
[
&
](
auto
&&
m
)
{
if
(
match
)
if
(
match
)
...
@@ -371,7 +376,20 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
...
@@ -371,7 +376,20 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
ins
);
get_module
(
mod
).
debug_print
(
ins
);
}
}
// If its already invalid dont validate it again
bool
invalidated
=
validate
and
get_module
(
mod
).
validate
()
!=
get_module
(
mod
).
end
();
m
.
apply
(
mod
,
r
);
m
.
apply
(
mod
,
r
);
if
(
validate
and
not
invalidated
)
{
auto
invalid
=
get_module
(
mod
).
validate
();
if
(
invalid
!=
get_module
(
mod
).
end
())
{
std
::
cout
<<
"Invalid program from match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Invalid instructions: "
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
invalid
->
inputs
());
get_module
(
mod
).
debug_print
(
invalid
);
}
}
match
=
true
;
match
=
true
;
},
},
ms
...);
ms
...);
...
@@ -520,6 +538,8 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
...
@@ -520,6 +538,8 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
{
return
not
ins
->
get_shape
().
standard
();
return
not
ins
->
get_shape
().
standard
();
}
}
MIGRAPHX_PRED_MATCHER
(
dynamic_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
dynamic
();
}
MIGRAPHX_PRED_MATCHER
(
static_shape
,
instruction_ref
ins
)
{
return
not
ins
->
get_shape
().
dynamic
();
}
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
{
{
return
ins
->
get_shape
().
broadcasted
();
return
ins
->
get_shape
().
broadcasted
();
...
...
src/include/migraphx/memory_coloring.hpp
View file @
baac1dab
...
@@ -33,7 +33,8 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -33,7 +33,8 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
module
;
struct
module
;
/**
/**
* Remove memory allocations. It uses graph coloring to find memory allocations that can be reused.
* Remove multiple memory allocations using graph coloring to find memory allocations that can be
* reused.
*/
*/
struct
memory_coloring
struct
memory_coloring
{
{
...
...
src/include/migraphx/module.hpp
View file @
baac1dab
...
@@ -178,6 +178,8 @@ struct module
...
@@ -178,6 +178,8 @@ struct module
bool
has_instruction
(
instruction_ref
ins
)
const
;
bool
has_instruction
(
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
get_returns
()
const
;
std
::
size_t
size
()
const
;
std
::
size_t
size
()
const
;
instruction_ref
begin
()
const
;
instruction_ref
begin
()
const
;
instruction_ref
end
()
const
;
instruction_ref
end
()
const
;
...
...
src/include/migraphx/msgpack.hpp
View file @
baac1dab
...
@@ -26,10 +26,12 @@
...
@@ -26,10 +26,12 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <functional>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
to_msgpack
(
const
value
&
v
,
std
::
function
<
void
(
const
char
*
,
std
::
size_t
)
>
writer
);
std
::
vector
<
char
>
to_msgpack
(
const
value
&
v
);
std
::
vector
<
char
>
to_msgpack
(
const
value
&
v
);
value
from_msgpack
(
const
std
::
vector
<
char
>&
buffer
);
value
from_msgpack
(
const
std
::
vector
<
char
>&
buffer
);
value
from_msgpack
(
const
char
*
buffer
,
std
::
size_t
size
);
value
from_msgpack
(
const
char
*
buffer
,
std
::
size_t
size
);
...
...
src/include/migraphx/onnx.hpp
View file @
baac1dab
...
@@ -37,7 +37,7 @@ struct onnx_options
...
@@ -37,7 +37,7 @@ struct onnx_options
std
::
size_t
default_dim_value
=
0
;
std
::
size_t
default_dim_value
=
0
;
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// parser throws)
/// parser throws)
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
,
0
};
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
};
/// Explicitly specify the dims of an input
/// Explicitly specify the dims of an input
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
=
{};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
=
{};
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
...
...
src/include/migraphx/op/allocate.hpp
View file @
baac1dab
...
@@ -44,7 +44,7 @@ struct allocate
...
@@ -44,7 +44,7 @@ struct allocate
std
::
string
name
()
const
{
return
"allocate"
;
}
std
::
string
name
()
const
{
return
"allocate"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
migraphx
::
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
migraphx
::
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
0
);
return
s
;
return
s
;
}
}
argument
compute
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
)
const
...
...
src/include/migraphx/op/argmax.hpp
View file @
baac1dab
...
@@ -62,7 +62,7 @@ struct argmax
...
@@ -62,7 +62,7 @@ struct argmax
if
(
s0
.
dynamic
())
if
(
s0
.
dynamic
())
{
{
auto
dyn_dims
=
s0
.
dyn_dims
();
auto
dyn_dims
=
s0
.
dyn_dims
();
dyn_dims
[
axis
]
=
{
1
,
1
,
0
};
dyn_dims
[
axis
]
=
{
1
,
1
};
return
{
shape
::
int64_type
,
dyn_dims
};
return
{
shape
::
int64_type
,
dyn_dims
};
}
}
else
else
...
...
src/include/migraphx/op/concat.hpp
View file @
baac1dab
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <array>
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
...
@@ -73,49 +74,87 @@ struct concat
...
@@ -73,49 +74,87 @@ struct concat
}
}
return
offsets
;
return
offsets
;
}
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
inputs
.
empty
())
// inputs can contain 1 or more shapes (variadic). compute_shape_op ensures there must
// be at least 1.
check_shapes
{
inputs
,
*
this
,
true
}.
same_ndims
().
same_type
();
if
(
std
::
none_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
{
{
MIGRAPHX_THROW
(
"CONCAT: Number of input tensors should exceed 0"
);
// Static input shapes
const
auto
&
first_shape_lens
=
inputs
.
front
().
lens
();
const
auto
&
type
=
inputs
.
front
().
type
();
for
(
std
::
size_t
ll
=
0
;
ll
<
first_shape_lens
.
size
();
ll
++
)
{
if
(
ll
!=
axis
)
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
return
s
.
lens
()[
ll
]
==
first_shape_lens
[
ll
];
}))
{
MIGRAPHX_THROW
(
"CONCAT: all input dimensions should match along axis "
+
std
::
to_string
(
ll
));
}
}
}
std
::
size_t
new_dim_axis
=
0
;
for
(
const
auto
&
input
:
inputs
)
{
const
auto
&
lens
=
input
.
lens
();
new_dim_axis
+=
lens
[
axis
];
}
std
::
vector
<
std
::
size_t
>
new_lens
=
first_shape_lens
;
new_lens
[
axis
]
=
new_dim_axis
;
return
shape
::
from_permutation
(
type
,
new_lens
,
find_permutation
(
inputs
));
}
}
else
if
(
std
::
all_of
(
const
auto
&
first_shape_lens
=
inputs
.
front
().
lens
();
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
const
auto
&
type
=
inputs
.
front
().
type
();
for
(
std
::
size_t
l
=
0
;
l
<
first_shape_lens
.
size
();
l
++
)
{
{
if
(
l
!=
axis
)
// Dynamic input shapes
for
(
std
::
size_t
index
=
0
;
index
<
inputs
[
0
].
ndim
();
index
++
)
{
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
if
(
index
!=
axis
)
return
s
.
lens
()[
l
]
==
first_shape_lens
[
l
];
}))
{
{
MIGRAPHX_THROW
(
"CONCAT: Non-axis dimensions should match"
);
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
shape
&
s
)
{
return
s
.
dyn_dims
()[
index
]
==
inputs
[
0
].
dyn_dims
()[
index
];
}))
MIGRAPHX_THROW
(
"CONCAT: all input dimensions should match in axis "
+
std
::
to_string
(
index
));
}
}
}
}
std
::
size_t
new_min
=
0
;
std
::
size_t
new_max
=
0
;
for
(
const
auto
&
input
:
inputs
)
{
auto
ddim
=
input
.
dyn_dims
()[
axis
];
new_min
+=
ddim
.
min
;
new_max
+=
ddim
.
max
;
}
auto
new_dims
=
inputs
[
0
].
dyn_dims
();
new_dims
[
axis
]
=
migraphx
::
shape
::
dynamic_dimension
{
new_min
,
new_max
};
return
{
inputs
[
0
].
type
(),
new_dims
};
}
}
std
::
size_t
new_dim_axis
=
0
;
else
for
(
const
auto
&
input
:
inputs
)
{
{
const
auto
&
lens
=
input
.
lens
();
MIGRAPHX_THROW
(
"CONCAT: Cannot mix static and dynamic input shapes."
);
new_dim_axis
+=
lens
[
axis
];
}
}
std
::
vector
<
std
::
size_t
>
new_lens
;
std
::
copy
(
first_shape_lens
.
begin
(),
first_shape_lens
.
end
(),
std
::
back_inserter
(
new_lens
));
new_lens
[
axis
]
=
new_dim_axis
;
return
shape
::
from_permutation
(
type
,
new_lens
,
find_permutation
(
inputs
));
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
std
::
vector
<
std
::
size_t
>
coffsets
=
compute_offsets
(
out
put_shape
,
args
);
std
::
vector
<
std
::
size_t
>
coffsets
=
compute_offsets
(
dyn_out
.
com
put
ed
_shape
,
args
);
for
(
std
::
size_t
l
=
0
;
l
<
args
.
size
();
l
++
)
for
(
std
::
size_t
l
=
0
;
l
<
args
.
size
();
l
++
)
{
{
auto
argl
=
args
[
l
];
auto
argl
=
args
[
l
];
visit_all
(
result
,
argl
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
argl
)([
&
](
auto
output
,
auto
input
)
{
auto
slice_shape
=
auto
slice_shape
=
shape
{
dyn_out
.
computed_shape
.
type
(),
shape
{
output_shape
.
type
(),
input
.
get_shape
().
lens
(),
output_shape
.
strides
()};
input
.
get_shape
().
lens
(),
auto
slice
=
make_view
(
slice_shape
,
output
.
data
()
+
coffsets
[
l
]);
dyn_out
.
computed_shape
.
strides
()};
auto
slice
=
make_view
(
slice_shape
,
output
.
data
()
+
coffsets
[
l
]);
std
::
copy
(
input
.
begin
(),
input
.
end
(),
slice
.
begin
());
std
::
copy
(
input
.
begin
(),
input
.
end
(),
slice
.
begin
());
});
});
}
}
...
...
src/include/migraphx/op/contiguous.hpp
View file @
baac1dab
...
@@ -48,7 +48,7 @@ struct contiguous
...
@@ -48,7 +48,7 @@ struct contiguous
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
s0
=
inputs
.
front
();
auto
s0
=
inputs
.
front
();
if
(
s0
.
dynamic
()
or
s0
.
standard
()
)
if
(
s0
.
dynamic
())
{
{
return
s0
;
return
s0
;
}
}
...
...
Prev
1
2
3
4
5
6
7
8
9
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment