Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4957715b
Commit
4957715b
authored
May 11, 2022
by
turneram
Browse files
Merge remote-tracking branch 'origin/develop' into dev2
parents
f99a3036
4ec8209f
Changes
63
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
367 additions
and
149 deletions
+367
-149
.clang-tidy
.clang-tidy
+1
-1
.github/workflows/ci.yaml
.github/workflows/ci.yaml
+6
-5
CMakeLists.txt
CMakeLists.txt
+10
-3
Dockerfile
Dockerfile
+2
-6
cmake/Embed.cmake
cmake/Embed.cmake
+1
-0
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-0
doc/src/reference/py.rst
doc/src/reference/py.rst
+7
-0
examples/nlp/python_bert_squad/BERT-Squad.ipynb
examples/nlp/python_bert_squad/BERT-Squad.ipynb
+1
-1
examples/nlp/python_bert_squad/README.md
examples/nlp/python_bert_squad/README.md
+1
-1
hip-clang.docker
hip-clang.docker
+2
-5
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/api/api.cpp
src/api/api.cpp
+32
-3
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+7
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+150
-121
src/api/migraphx.py
src/api/migraphx.py
+4
-0
src/driver/marker_roctx.cpp
src/driver/marker_roctx.cpp
+1
-1
src/include/migraphx/filesystem.hpp
src/include/migraphx/filesystem.hpp
+4
-1
src/include/migraphx/op/gathernd.hpp
src/include/migraphx/op/gathernd.hpp
+131
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-0
src/include/migraphx/optional.hpp
src/include/migraphx/optional.hpp
+4
-1
No files found.
.clang-tidy
View file @
4957715b
...
...
@@ -4,7 +4,7 @@ CheckOptions:
- key: bugprone-unused-return-value.CheckedFunctions
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: cppcoreguidelines-macro-usage.AllowedRegexp
value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_'
value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|
DEPRECATED|
STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_'
- key: modernize-loop-convert.MinConfidence
value: risky
- key: modernize-loop-convert.NamingStyle
...
...
.github/workflows/ci.yaml
View file @
4957715b
...
...
@@ -15,7 +15,8 @@ jobs:
steps
:
-
name
:
Free space
run
:
sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia*
run
:
sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia* /usr/local/lib/android
-
uses
:
actions/checkout@v2
# In this step, this action saves a list of existing images,
...
...
@@ -63,7 +64,7 @@ jobs:
steps
:
-
name
:
Free space
run
:
sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia*
run
:
sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia*
/usr/local/lib/android
-
uses
:
actions/checkout@v2
# In this step, this action saves a list of existing images,
...
...
@@ -108,7 +109,7 @@ jobs:
steps
:
-
name
:
Free space
run
:
sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia*
run
:
sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia*
/usr/local/lib/android
-
uses
:
actions/checkout@v2
# In this step, this action saves a list of existing images,
...
...
@@ -143,7 +144,7 @@ jobs:
steps
:
-
name
:
Free space
run
:
sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia*
run
:
sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia*
/usr/local/lib/android
-
uses
:
actions/checkout@v2
-
name
:
Set up Python
uses
:
actions/setup-python@v2
...
...
@@ -182,7 +183,7 @@ jobs:
steps
:
-
name
:
Free space
run
:
sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia*
run
:
sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia*
/usr/local/lib/android
-
uses
:
actions/checkout@v2
-
name
:
Set up Python
uses
:
actions/setup-python@v2
...
...
CMakeLists.txt
View file @
4957715b
...
...
@@ -42,7 +42,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
include
(
ROCMSetupVersion
)
rocm_setup_version
(
VERSION 2.
2
)
rocm_setup_version
(
VERSION 2.
3
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
...
...
@@ -93,11 +93,14 @@ rocm_enable_clang_tidy(
modernize-*
performance-*
readability-*
-bugprone-signed-char-misuse
-bugprone-easily-swappable-parameters
-bugprone-implicit-widening-of-multiplication-result
-bugprone-macro-parentheses
-bugprone-signed-char-misuse
# Disable the aliased reserved identifiers
-cert-dcl37-c
-cert-dcl51-cpp
-cert-err33-c
-cert-str34-c
# Disable all alpha checks by default
-clang-analyzer-alpha*
...
...
@@ -127,6 +130,7 @@ rocm_enable_clang_tidy(
-cppcoreguidelines-pro-type-union-access
-cppcoreguidelines-pro-type-vararg
-cppcoreguidelines-special-member-functions
-cppcoreguidelines-virtual-class-destructor
-google-readability-*
-google-runtime-int
-google-runtime-references
...
...
@@ -144,8 +148,10 @@ rocm_enable_clang_tidy(
-readability-convert-member-functions-to-static
-readability-else-after-return
-readability-function-cognitive-complexity
-readability-identifier-length
-readability-named-parameter
-readability-redundant-string-init
-readability-suspicious-call-argument
-readability-uppercase-literal-suffix
-*-avoid-c-arrays
-*-explicit-constructor
...
...
@@ -178,7 +184,7 @@ rocm_enable_cppcheck(
style
performance
portability
SUPPRESS
SUPPRESS
ConfigurationNotChecked
unmatchedSuppression
unusedFunction
...
...
@@ -216,6 +222,7 @@ rocm_enable_cppcheck(
CPPCHECK=1
__device__=
__host__=
__global__=
)
enable_testing
()
...
...
Dockerfile
View file @
4957715b
FROM
ubuntu:
18
.04
FROM
ubuntu:
20
.04
ARG
PREFIX=/usr/local
...
...
@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN
dpkg
--add-architecture
i386
# Add rocm repository
RUN
sh
-c
'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/
4.5
/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
RUN
sh
-c
'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/
5.0.2
/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--allow-unauthenticated
\
...
...
@@ -16,16 +16,12 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
cmake
\
curl
\
doxygen
\
g++-5
\
g++-7
\
gdb
\
git
\
lcov
\
locales
\
pkg-config
\
python
\
python-dev
\
python-pip
\
python3
\
python3-dev
\
python3-pip
\
...
...
cmake/Embed.cmake
View file @
4957715b
...
...
@@ -94,5 +94,6 @@ function(add_embed_library EMBED_NAME)
generate_embed_source
(
${
EMBED_NAME
}
SRC
${
SRC_FILE
}
HEADER
${
HEADER_FILE
}
OBJECTS
${
OUTPUT_FILES
}
SYMBOLS
${
SYMBOLS
}
)
add_library
(
${
EMBED_NAME
}
STATIC
${
OUTPUT_FILES
}
"
${
SRC_FILE
}
"
)
target_include_directories
(
${
EMBED_NAME
}
PUBLIC
"
${
EMBED_DIR
}
/include"
)
target_compile_options
(
${
EMBED_NAME
}
PRIVATE -Wno-reserved-identifier
)
set_target_properties
(
${
EMBED_NAME
}
PROPERTIES POSITION_INDEPENDENT_CODE On
)
endfunction
()
cmake/EnableCompilerWarnings.cmake
View file @
4957715b
...
...
@@ -96,6 +96,7 @@ else()
-Wno-gnu-zero-variadic-macro-arguments
-Wno-missing-prototypes
-Wno-nested-anon-types
-Wno-option-ignored
-Wno-padded
-Wno-shorten-64-to-32
-Wno-sign-conversion
...
...
doc/src/reference/py.rst
View file @
4957715b
...
...
@@ -146,6 +146,13 @@ module
:param list[module] mod_args: optional list of module arguments to the operator.
:rtype instruction
.. py:method:: add_literal(data)
Adds constant or literal data of provided shape into the module from python buffer which includes numpy array.
:param py::buffer data: Python buffer or numpy array
:rtype instruction
.. py:method:: add_parameter(name, shape)
Adds a parameter to the module with provided name and shape.
...
...
examples/nlp/python_bert_squad/BERT-Squad.ipynb
View file @
4957715b
...
...
@@ -62,7 +62,7 @@
"metadata": {},
"outputs": [],
"source": [
"!wget -nc https://github.com/onnx/models/
raw/master
/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx"
"!wget -nc https://github.com/onnx/models/
blob/main
/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx"
]
},
{
...
...
examples/nlp/python_bert_squad/README.md
View file @
4957715b
...
...
@@ -23,7 +23,7 @@ unzip uncased_L-12_H-768_A-12.zip
```
5) Get BERT ONNX model (bertsquad-10.onnx):
```
wget https://github.com/onnx/models/
raw/master
/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx
wget https://github.com/onnx/models/
blob/main
/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx
```
6) Run the inference, it will compile and run the model on three questions and small data provided in
`inputs.json`
:
```
...
...
hip-clang.docker
View file @
4957715b
FROM
ubuntu:
18
.04
FROM
ubuntu:
20
.04
ARG
PREFIX=/usr/local
...
...
@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN
dpkg
--add-architecture
i386
# Add rocm repository
RUN
sh
-c
'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/
4.5
/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
RUN
sh
-c
'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/
5.0.2
/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--allow-unauthenticated
\
...
...
@@ -20,9 +20,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
git
\
lcov
\
pkg-config
\
python
\
python-dev
\
python-pip
\
python3
\
python3-dev
\
python3-pip
\
...
...
src/CMakeLists.txt
View file @
4957715b
...
...
@@ -109,6 +109,7 @@ register_migraphx_ops(
flatten
floor
gather
gathernd
get_tuple_elem
greater
gru
...
...
src/api/api.cpp
View file @
4957715b
...
...
@@ -401,7 +401,8 @@ extern "C" struct migraphx_instruction;
struct
migraphx_instruction
{
template
<
class
...
Ts
>
migraphx_instruction
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
migraphx_instruction
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
}
migraphx
::
instruction_ref
object
;
...
...
@@ -411,7 +412,8 @@ extern "C" struct migraphx_instructions;
struct
migraphx_instructions
{
template
<
class
...
Ts
>
migraphx_instructions
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
migraphx_instructions
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
}
std
::
vector
<
migraphx
::
instruction_ref
>
object
;
...
...
@@ -421,7 +423,8 @@ extern "C" struct migraphx_modules;
struct
migraphx_modules
{
template
<
class
...
Ts
>
migraphx_modules
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
migraphx_modules
(
Ts
&&
...
xs
)
:
object
(
std
::
forward
<
Ts
>
(
xs
)...)
// NOLINT(readability-redundant-member-init)
{
}
std
::
vector
<
migraphx
::
module
*>
object
;
...
...
@@ -1069,6 +1072,22 @@ migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
shape
,
const
char
*
buffer
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
module
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter module: Null pointer"
);
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
out
=
allocate
<
migraphx_instruction_t
>
(
(
module
->
object
).
add_literal
((
shape
->
object
),
(
buffer
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const
char
*
name
,
...
...
@@ -1691,6 +1710,16 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_context_get_queue
(
void
**
out
,
migraphx_context_t
context
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
context
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter context: Null pointer"
);
*
out
=
(
context
->
object
).
get_queue
().
unsafe_get
();
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
)
{
...
...
src/api/include/migraphx/migraphx.h
View file @
4957715b
...
...
@@ -258,6 +258,11 @@ migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instructi
migraphx_instructions_t
args
,
migraphx_modules_t
module_refs
);
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
shape
,
const
char
*
buffer
);
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const
char
*
name
,
...
...
@@ -433,6 +438,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
);
migraphx_status
migraphx_context_get_queue
(
void
**
out
,
migraphx_context_t
context
);
migraphx_status
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
...
...
src/api/include/migraphx/migraphx.hpp
View file @
4957715b
...
...
@@ -15,6 +15,16 @@ namespace migraphx {
inline
namespace
api
{
// NOLINT
#endif
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
{
...
...
@@ -99,34 +109,22 @@ struct iota_iterator
return
it
;
}
// TODO: operator->
reference
operator
*
()
const
{
return
(
*
f
)(
index
);
}
};
reference
operator
*
()
const
{
return
f
(
index
);
}
template
<
class
F
,
class
Iterator
>
inline
iota_iterator
<
F
,
Iterator
>
operator
+
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
{
return
iota_iterator
<
F
,
Iterator
>
(
x
.
index
+
y
.
index
,
x
.
f
);
}
friend
iota_iterator
operator
+
(
iota_iterator
x
,
iota_iterator
y
)
{
return
iota_iterator
(
x
.
index
+
y
.
index
,
x
.
f
);
}
template
<
class
F
,
class
Iterator
>
inline
iota_iterator
<
F
,
Iterator
>
operator
-
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
{
return
iota_iterator
<
F
,
Iterator
>
(
x
.
index
-
y
.
index
,
x
.
f
);
}
friend
iota_iterator
operator
-
(
iota_iterator
x
,
iota_iterator
y
)
{
return
iota_iterator
(
x
.
index
-
y
.
index
,
x
.
f
);
}
template
<
class
F
,
class
Iterator
>
inline
bool
operator
==
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
==
y
.
index
;
}
friend
bool
operator
==
(
iota_iterator
x
,
iota_iterator
y
)
{
return
x
.
index
==
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
inline
bool
operator
!=
(
iota_iterator
<
F
,
Iterator
>
x
,
iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
!=
y
.
index
;
}
friend
bool
operator
!=
(
iota_iterator
x
,
iota_iterator
y
)
{
return
x
.
index
!=
y
.
index
;
}
};
template
<
class
Derived
>
struct
array_base
...
...
@@ -136,8 +134,20 @@ struct array_base
template
<
class
T
>
using
value_type_t
=
decltype
(
std
::
declval
<
T
>
()[
0
]);
struct
iterator_read
{
const
Derived
*
self
;
template
<
class
D
=
Derived
>
value_type_t
<
D
>
operator
()(
size_t
pidx
)
const
{
return
(
*
self
)[
pidx
];
}
};
template
<
class
T
>
using
iterator_t
=
iota_iterator
<
typename
T
::
iterator_read
>
;
using
iterator_t
=
iota_iterator
<
iterator_read
>
;
bool
empty
()
const
{
return
derived
().
size
()
==
0
;
}
template
<
class
D
=
Derived
>
value_type_t
<
D
>
front
()
const
...
...
@@ -154,13 +164,13 @@ struct array_base
template
<
class
D
=
Derived
>
iterator_t
<
D
>
begin
()
const
{
return
{
0
,
{
derived
()
.
get_handle_ptr
()
}};
return
{
0
,
{
&
derived
()}};
}
template
<
class
D
=
Derived
>
iterator_t
<
D
>
end
()
const
{
return
{
derived
().
size
(),
{
derived
()
.
get_handle_ptr
()
}};
return
{
derived
().
size
(),
{
&
derived
()}};
}
};
...
...
@@ -200,9 +210,25 @@ struct borrow
{
};
template
<
class
T
>
struct
share
{
share
(
std
::
shared_ptr
<
T
>
p
)
:
ptr
(
std
::
move
(
p
))
{}
template
<
class
U
>
std
::
shared_ptr
<
U
>
alias
(
U
*
p
)
const
{
return
std
::
shared_ptr
<
U
>
{
ptr
,
p
};
}
private:
std
::
shared_ptr
<
T
>
ptr
;
};
template
<
class
Derived
,
class
T
,
class
D
,
D
Deleter
,
class
A
,
A
Assigner
>
struct
handle_base
:
handle_lookup
<
Derived
,
std
::
remove_cv_t
<
T
>>
{
using
handle_type
=
T
;
handle_base
()
:
m_handle
(
nullptr
)
{}
template
<
class
F
,
class
...
Ts
>
void
make_handle
(
F
f
,
Ts
&&
...
xs
)
...
...
@@ -231,6 +257,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
m_handle
=
std
::
shared_ptr
<
U
>
{
ptr
,
[](
U
*
)
{}};
}
template
<
class
U
,
class
V
>
void
set_handle
(
U
*
ptr
,
share
<
V
>
b
)
{
m_handle
=
std
::
shared_ptr
<
T
>
{
ptr
,
[
b
](
U
*
)
{}};
}
share
<
T
>
share_handle
()
const
{
return
{
m_handle
};
}
template
<
class
U
>
void
assign_to_handle
(
U
*
x
)
{
...
...
@@ -241,6 +275,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std
::
shared_ptr
<
T
>
m_handle
;
};
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template
<
class
Base
>
struct
interface_base
:
Base
{
...
...
@@ -269,6 +314,7 @@ struct interface_base : Base
T
**
y
=
reinterpret_cast
<
T
**>
(
out
);
T
*
x
=
reinterpret_cast
<
T
*>
(
input
);
assert
(
x
!=
nullptr
and
y
!=
nullptr
and
*
y
==
nullptr
);
// cppcheck-suppress useSmartPointer
*
y
=
new
T
(
*
x
);
// NOLINT
});
};
...
...
@@ -294,6 +340,7 @@ struct interface_base : Base
template
<
class
T
,
class
Setter
,
class
F
>
void
set_auto_fp
(
Setter
setter
,
F
f
)
{
// cppcheck-suppress constParameter
return
set_fp
<
T
>
(
setter
,
[
=
](
T
&
obj
,
auto
out
,
auto
...
xs
)
{
auto_invoke
(
f
,
out
,
obj
,
auto_convert_param
(
rank
<
2
>
{},
xs
)...);
});
...
...
@@ -398,11 +445,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{
shape
()
{}
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
shape
(
migraphx_shape
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
shape
(
migraphx_shape
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shape
);
/// Construct a scalar shape
shape
(
migraphx_shape_datatype_t
type
)
...
...
@@ -479,10 +525,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
argument
()
{}
argument
(
migraphx_argument
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
migraphx_argument
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
argument
);
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
shape
pshape
,
void
*
pbuffer
)
...
...
@@ -494,7 +539,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
const_migraphx_shape_t
pout
;
call
(
&
migraphx_argument_shape
,
&
pout
,
this
->
get_handle_ptr
());
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
char
*
data
()
const
...
...
@@ -526,9 +571,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{
target
()
{}
target
(
migraphx_target
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
target
(
migraphx_target
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
target
);
/// Construct a target from its name
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
...
...
@@ -538,15 +581,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
program_parameter_shapes
()
{}
program_parameter_shapes
(
migraphx_program_parameter_shapes
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
program_parameter_shapes
(
migraphx_program_parameter_shapes
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameter_shapes
);
size_t
size
()
const
{
...
...
@@ -559,7 +594,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
const_migraphx_shape_t
pout
;
call
(
&
migraphx_program_parameter_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pname
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
std
::
vector
<
const
char
*>
names
()
const
...
...
@@ -576,10 +611,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
{
program_parameters
(
migraphx_program_parameters
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
program_parameters
(
migraphx_program_parameters
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameters
);
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
program_parameters
()
{
this
->
make_handle
(
&
migraphx_program_parameters_create
);
}
...
...
@@ -604,9 +638,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
{
arguments
(
migraphx_arguments
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
arguments
(
migraphx_arguments
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
arguments
);
size_t
size
()
const
{
...
...
@@ -619,27 +651,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
const_migraphx_argument_t
pout
;
call
(
&
migraphx_arguments_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
struct
iterator_read
{
migraphx_arguments
*
self
;
argument
operator
()(
size_t
pidx
)
const
{
const_migraphx_argument_t
pout
;
call
(
&
migraphx_arguments_get
,
&
pout
,
self
,
pidx
);
return
{
pout
};
}
};
};
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
{
shapes
(
migraphx_shapes
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
shapes
(
migraphx_shapes
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shapes
);
size_t
size
()
const
{
...
...
@@ -652,26 +670,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
const_migraphx_shape_t
pout
;
call
(
&
migraphx_shapes_get
,
&
pout
,
this
->
get_handle_ptr
(),
pidx
);
return
{
pout
};
return
{
pout
,
this
->
share_handle
()
};
}
struct
iterator_read
{
migraphx_shapes
*
self
;
shape
operator
()(
size_t
pidx
)
const
{
const_migraphx_shape_t
pout
;
call
(
&
migraphx_shapes_get
,
&
pout
,
self
,
pidx
);
return
{
pout
};
}
};
};
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
operation
(
migraphx_operation
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
operation
(
migraphx_operation
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
operation
);
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
...
...
@@ -689,15 +694,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
{
instruction
(
migraphx_instruction
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instruction
);
};
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
{
instructions
(
migraphx_instructions
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
instructions
(
migraphx_instructions
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instructions
);
template
<
class
...
Ts
>
instructions
(
Ts
...
xs
)
...
...
@@ -711,33 +713,36 @@ struct module;
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
{
modules
(
migraphx_modules
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
modules
(
migraphx_modules
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
modules
);
template
<
class
...
Ts
>
modules
(
Ts
...
xs
)
{
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
mm
...};
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
get_handle_ptr
()
...};
this
->
make_handle
(
&
migraphx_modules_create
,
a
.
data
(),
a
.
size
());
}
};
struct
module
{
migraphx_module_t
mm
;
MIGRAPHX_DEPRECATED
(
"Constructor without lifetime annotation is deprecated."
)
module
(
migraphx_module
*
m
)
:
mm
(
std
::
shared_ptr
<
migraphx_module
*>
(),
m
)
{}
module
(
const
migraphx_module
_t
&
m
)
:
mm
(
m
)
{}
module
(
migraphx_module
*
m
,
borrow
)
:
mm
(
std
::
shared_ptr
<
migraphx_module
*>
(),
m
)
{}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
);
}
template
<
class
T
>
module
(
migraphx_module
*
m
,
share
<
T
>
b
)
:
mm
(
b
.
alias
(
m
))
{
}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
.
get
());
}
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
)
{
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction
,
&
op_ins
,
mm
,
mm
.
get
()
,
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
...
...
@@ -750,40 +755,72 @@ struct module
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction_with_mod_args
,
&
op_ins
,
mm
,
mm
.
get
()
,
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
(),
module_args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
}
template
<
typename
T
>
instruction
add_literal
(
const
migraphx
::
shape
&
s
,
T
*
buffer
)
{
migraphx_instruction_t
literal_ins
;
const
auto
*
buffer_ptr
=
reinterpret_cast
<
const
char
*>
(
buffer
);
call
(
&
migraphx_module_add_literal
,
&
literal_ins
,
mm
.
get
(),
s
.
get_handle_ptr
(),
buffer_ptr
);
return
instruction
(
literal_ins
,
own
{});
}
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
{
migraphx_instruction_t
param_ins
;
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
,
name
.
c_str
(),
s
.
get_handle_ptr
());
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
.
get
(),
name
.
c_str
(),
s
.
get_handle_ptr
());
return
instruction
(
param_ins
,
own
{});
}
instruction
add_return
(
const
migraphx
::
instructions
&
args
)
{
migraphx_instruction_t
ret_ins
;
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
,
args
.
get_handle_ptr
());
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
.
get
()
,
args
.
get_handle_ptr
());
return
instruction
(
ret_ins
,
own
{});
}
migraphx_module_t
get_handle_ptr
()
const
{
return
mm
.
get
();
}
private:
std
::
shared_ptr
<
migraphx_module
>
mm
;
};
struct
context
{
migraphx_context
_t
ctx
;
context
(
migraphx_context
*
p
,
borrow
)
:
ctx
(
std
::
shared_ptr
<
migraphx_context
*>
(),
p
)
{}
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
);
}
template
<
class
T
>
context
(
migraphx_context
*
p
,
share
<
T
>
b
)
:
ctx
(
b
.
alias
(
p
))
{
}
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
.
get
());
}
template
<
class
T
>
T
get_queue
()
{
void
*
out
;
call
(
&
migraphx_context_get_queue
,
&
out
,
ctx
.
get
());
// TODO: check type here
return
reinterpret_cast
<
T
>
(
out
);
}
private:
std
::
shared_ptr
<
migraphx_context
>
ctx
;
};
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
{
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
compile_options
(
migraphx_compile_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
());
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
compile_options
);
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
...
...
@@ -807,9 +844,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
(
migraphx_program
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
program
(
migraphx_program
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program
);
/// Compile the program for a specific target to be ran on
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
...
...
@@ -872,21 +907,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
migraphx_module_t
p_modu
;
call
(
&
migraphx_program_get_main_module
,
&
p_modu
,
this
->
get_handle_ptr
());
return
module
{
p_modu
};
return
module
{
p_modu
,
this
->
share_handle
()
};
}
context
experimental_get_context
()
{
migraphx_context_t
ctx
;
call
(
&
migraphx_program_experimental_get_context
,
&
ctx
,
this
->
get_handle_ptr
());
return
context
{
ctx
};
return
context
{
ctx
,
this
->
share_handle
()
};
}
module
create_module
(
const
std
::
string
&
name
)
{
migraphx_module_t
p_modu
;
call
(
&
migraphx_program_create_module
,
&
p_modu
,
this
->
get_handle_ptr
(),
name
.
data
());
return
module
{
p_modu
};
return
module
{
p_modu
,
this
->
share_handle
()
};
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
...
...
@@ -895,10 +930,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
file_options
);
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
file_options
(
migraphx_file_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
());
}
// set file format
void
set_file_format
(
const
char
*
format
)
{
...
...
@@ -938,7 +972,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
onnx_options
(
migraphx_onnx_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
onnx_options
);
/// Make onnx parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
...
@@ -1020,7 +1054,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
tf_options
(
migraphx_tf_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
tf_options
);
/// Make tf parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
...
@@ -1073,7 +1107,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
quantize_op_names
(
migraphx_quantize_op_names
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_op_names
);
void
add
(
const
std
::
string
&
name
)
{
...
...
@@ -1098,12 +1132,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
quantize_int8_options
(
migraphx_quantize_int8_options
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
quantize_int8_options
(
migraphx_quantize_int8_options
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_int8_options
);
/// Add an operator that should be quantized
void
add_op_name
(
const
std
::
string
&
name
)
...
...
src/api/migraphx.py
View file @
4957715b
...
...
@@ -212,6 +212,9 @@ def module(h):
module_refs
=
'std::vector<migraphx::module*>'
),
fname
=
'add_instruction'
,
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_literal'
,
api
.
params
(
shape
=
'const migraphx::shape&'
,
buffer
=
'const char*'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_parameter'
,
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
returns
=
'migraphx::instruction_ref'
)
...
...
@@ -403,6 +406,7 @@ api.add_function('migraphx_quantize_int8',
@
auto_handle
(
ref
=
True
)
def
context
(
h
):
h
.
method
(
'finish'
,
const
=
True
)
h
.
method
(
'get_queue'
,
returns
=
'void*'
,
fname
=
'get_queue().unsafe_get'
)
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
...
...
src/driver/marker_roctx.cpp
View file @
4957715b
...
...
@@ -17,7 +17,7 @@ class marker_roctx
std
::
function
<
int
(
const
char
*
)
>
sym_roctx_range_push
;
std
::
function
<
int
()
>
sym_roctx_range_pop
;
uint64_t
range_id
;
uint64_t
range_id
=
0
;
public:
marker_roctx
()
...
...
src/include/migraphx/filesystem.hpp
View file @
4957715b
...
...
@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1
#else
...
...
src/include/migraphx/op/gathernd.hpp
0 → 100644
View file @
4957715b
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
gathernd
{
int
batch_dims
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
batch_dims
,
"batch_dims"
));
}
std
::
string
name
()
const
{
return
"gathernd"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
r
=
inputs
.
front
().
lens
().
size
();
auto
q
=
inputs
.
back
().
lens
().
size
();
auto
k
=
inputs
.
back
().
lens
().
back
();
if
(
k
>
r
-
batch_dims
)
{
MIGRAPHX_THROW
(
"GATHERND: Indices of length "
+
std
::
to_string
(
k
)
+
" cannot be used to access data of rank "
+
std
::
to_string
(
r
-
batch_dims
));
}
auto
indices_lens_iter
=
inputs
.
back
().
lens
().
begin
();
auto
output_lens_size
=
q
+
r
-
k
-
batch_dims
-
1
;
std
::
vector
<
std
::
size_t
>
output_lens
(
output_lens_size
);
std
::
copy
(
indices_lens_iter
,
indices_lens_iter
+
(
q
-
1
),
output_lens
.
begin
());
if
(
k
<
r
-
batch_dims
)
{
auto
data_lens
=
inputs
.
front
().
lens
();
std
::
copy
(
data_lens
.
begin
()
+
batch_dims
+
k
,
data_lens
.
end
(),
output_lens
.
begin
()
+
q
-
1
);
}
shape
output_shape
{
inputs
.
front
().
type
(),
output_lens
};
return
output_shape
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
indices_shape
=
indices
.
get_shape
();
auto
indices_shape_lens
=
indices_shape
.
lens
();
auto
data_shape
=
data
.
get_shape
();
auto
data_shape_lens
=
data_shape
.
lens
();
auto
k
=
indices_shape
.
lens
().
back
();
const
auto
num_slice_dims
=
k
;
std
::
size_t
num_slices
=
std
::
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
slice_size
=
std
::
accumulate
(
data_shape_lens
.
begin
()
+
k
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
num_batches
=
std
::
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
data_batch_stride
=
std
::
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
std
::
vector
<
std
::
size_t
>
sizes_from_slice_dims
(
num_slice_dims
);
{
auto
running_product
=
slice_size
;
for
(
std
::
size_t
i
=
0
;
i
<
num_slice_dims
;
++
i
)
{
sizes_from_slice_dims
[
num_slice_dims
-
1
-
i
]
=
running_product
;
running_product
*=
data_shape_lens
[
batch_dims
+
num_slice_dims
-
1
-
i
];
}
}
std
::
vector
<
std
::
size_t
>
input_slice_offsets
(
num_slices
);
par_for
(
num_slices
,
[
&
](
const
auto
i
)
{
std
::
size_t
batch_idx
=
i
/
num_slices_per_batch
;
auto
slice_indices
=
indices
.
begin
()
+
(
i
*
num_slice_dims
);
std
::
size_t
relative_slice_offset
=
0
;
for
(
size_t
dim_idx
=
0
;
dim_idx
<
num_slice_dims
;
++
dim_idx
)
{
int64_t
index
=
*
(
slice_indices
+
dim_idx
);
const
std
::
size_t
input_dim_idx
=
batch_dims
+
dim_idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
if
(
index
<
-
static_cast
<
int64_t
>
(
input_dim
)
or
index
>=
static_cast
<
int64_t
>
(
input_dim
))
MIGRAPHX_THROW
(
"GatherND: index "
+
std
::
to_string
(
index
)
+
" is out of bounds for dim of len "
+
std
::
to_string
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
relative_slice_offset
+=
index
*
sizes_from_slice_dims
[
dim_idx
];
}
input_slice_offsets
[
i
]
=
(
batch_idx
*
data_batch_stride
)
+
relative_slice_offset
;
});
par_for
(
num_slices
*
slice_size
,
[
&
](
const
auto
i
)
{
auto
slice_offset
=
input_slice_offsets
[
i
/
slice_size
];
output
[
i
]
=
data
[
slice_offset
+
i
%
slice_size
];
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
4957715b
...
...
@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
...
...
src/include/migraphx/optional.hpp
View file @
4957715b
...
...
@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_OPTIONAL 1
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#elif defined(__has_include)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment