Unverified Commit 40fbef9b authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into threaded_nms

parents d164b151 aeb9f78c
...@@ -82,6 +82,10 @@ Print out program in text format. ...@@ -82,6 +82,10 @@ Print out program in text format.
Print out program in binary format. Print out program in binary format.
.. option:: --py
Print out program using python API.
.. option:: --output, -o [std::string] .. option:: --output, -o [std::string]
Output to file. Output to file.
......
...@@ -3,18 +3,10 @@ ...@@ -3,18 +3,10 @@
You can adapt this file completely to your liking, but it should at least You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive. contain the root `toctree` directive.
Welcome to AMD MIGraphX's documentation! AMD MIGraphX documentation
======================================== ==========================
.. toctree::
:maxdepth: 3
:caption: Contents:
py_user_guide
cpp_user_guide
driver
contributor_guide
AMD MIGraphX is AMD's graph inference engine that accelerates machine learning model inference.
Indices and tables Indices and tables
================== ==================
......
...@@ -6,9 +6,9 @@ Python Reference ...@@ -6,9 +6,9 @@ Python Reference
shape shape
----- -----
.. py:class:: shape(type, lens, strides=None) .. py:class:: shape(type, lens, strides=None, dyn_dims)
Describes the shape of a tensor. This includes size, layout, and data type/ Describes the shape of a tensor. This includes size, layout, and data type. Can be a dynamic shape by using dyn_dims.
.. py:method:: type() .. py:method:: type()
...@@ -34,6 +34,12 @@ shape ...@@ -34,6 +34,12 @@ shape
:rtype: int :rtype: int
.. py:method:: dyn_dims()
The dynamic dimensions of the shape.
:rtype: list[dynamic_dimension]
.. py:method:: bytes() .. py:method:: bytes()
The number of bytes the shape uses. The number of bytes the shape uses.
...@@ -46,6 +52,12 @@ shape ...@@ -46,6 +52,12 @@ shape
:rtype: int :rtype: int
.. py:method:: ndim()
The number of dimensions for the shape.
:rtype: int
.. py:method:: packed() .. py:method:: packed()
Returns true if the shape is packed. Returns true if the shape is packed.
...@@ -64,6 +76,12 @@ shape ...@@ -64,6 +76,12 @@ shape
:rtype: bool :rtype: bool
.. py:method:: dynamic()
Returns true if the shape is dynamic.
:rtype: bool
.. py:method:: standard() .. py:method:: standard()
Returns true if the shape is a standard shape. That is, the shape is both packed and not transposed. Returns true if the shape is a standard shape. That is, the shape is both packed and not transposed.
...@@ -76,6 +94,18 @@ shape ...@@ -76,6 +94,18 @@ shape
:rtype: bool :rtype: bool
dynamic_dimension
--------
.. py:class:: dynamic_dimension(min, max, optimals)
Construct a dynamic_dimension from a minimum, a maximum, and optionally a set of optimals.
.. py:method:: is_fixed()
Returns true if the dynamic_dimension is fixed.
:rtype : int
argument argument
-------- --------
...@@ -121,6 +151,15 @@ argument ...@@ -121,6 +151,15 @@ argument
:rtype: argument :rtype: argument
.. py:function:: create_argument(s, values)
Create an argument of shape s with a set of values.
:param shape s: Shape of argument to create.
:param list values: Values to put in the argument. Must be the same number of elements as the shape.
:rtype: argument
.. py:function:: argument_from_pointer(shape, address) .. py:function:: argument_from_pointer(shape, address)
Create argument from data stored in given address without copy. Create argument from data stored in given address without copy.
...@@ -292,8 +331,10 @@ parse_onnx ...@@ -292,8 +331,10 @@ parse_onnx
Load and parse an onnx file. Load and parse an onnx file.
:param str filename: Path to file. :param str filename: Path to file.
:param str default_dim_value: default batch size to use (if not specified in onnx file). :param str default_dim_value: default dimension to use (if not specified in onnx file).
:param dynamic_dimension default_dyn_dim_value: default dynamic_dimension value to use.
:param str map_input_dims: Explicitly specify the dims of an input. :param str map_input_dims: Explicitly specify the dims of an input.
:param list[dynamic_dimension] map_dyn_input_dims: Explicitly specify the dynamic_dimensions of an input.
:param str skip_unknown_operators: Continue parsing onnx file if an unknown operator is found. :param str skip_unknown_operators: Continue parsing onnx file if an unknown operator is found.
:param str print_program_on_error: Print program if an error occurs. :param str print_program_on_error: Print program if an error occurs.
:param int max_loop_iterations: Maximum iteration number for the loop operator. :param int max_loop_iterations: Maximum iteration number for the loop operator.
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
cmake_minimum_required(VERSION 3.5)
project (cpp_dynamic_batch)
set (CMAKE_CXX_STANDARD 14)
set (EXAMPLE dynamic_batch)
list (APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package (migraphx)
message("source file: " ${EXAMPLE}.cpp " ---> bin: " ${EXAMPLE})
add_executable(${EXAMPLE} ${EXAMPLE}.cpp)
target_link_libraries(${EXAMPLE} migraphx::c)
# Running ONNX model with dynamic batch
## Description
This examples demonstrates how to run a graph program with dynamic batch using the MIGraphX C++ API.
## Creating dynamic dimension objects
`dynamic_dimension` objects are used in MIGraphX to specify a range of dimension values from a minimum value to a maximum value and optimal values that the tensor can be at model evaluation time.
A dynamic shape is defined by a list of `dynamic_dimensions` while a static shape only has fixed dimension values.
For example, a `dynamic_dimension` with `{min:1, max:10, optimals:{1, 4, 10}}` means that the dimension can be any value from 1 through 10 with the optimal values being 1, 4, and 10.
Supplied optimal values may allow MIGraphX to optimize the program for those specific shapes.
A fixed `dynamic_dimension` can be specified by setting the `min` and `max` to the same value (ex. `{min:3, max:3}`).
A dynamic shape specified solely by fixed `dynamic_dimension` objects will be converted to a static shape during parsing.
This can be useful for setting a static shape using the `set_dyn_input_parameter_shape()` method discussed later in this document.
## Parsing
ONNX graphs [ONNX](https://onnx.ai/get-started.html) can be parsed by MIGraphX to create a runnable program with dynamic batch sizes.
The dynamic batch range must be specified by a `dynamic_dimension` object.
One method to set the `dynamic_dimension` object works for ONNX files that only have symbolic variables for the batch dimensions:
```
migraphx::program p;
migraphx::onnx_options options;
options.set_default_dyn_dim_value(migraphx::dynamic_dimension{1, 4, {2, 4}});
p = parse_onnx(input_file, options);
```
Another option that can run any ONNX model with dynamic batch sizes uses the dynamic input map where the entire shape of the input parameter is supplied:
```
migraphx::program p;
migraphx::onnx_options options;
migraphx::dynamic_dimensions dyn_dims = {migraphx::dynamic_dimension{1, 4, {2, 4}},
migraphx::dynamic_dimension{3, 3},
migraphx::dynamic_dimension{4, 4},
migraphx::dynamic_dimension{4, 4}};
options.set_dyn_input_parameter_shape("input", dyn_dims);
p = parse_onnx(input_file, options);
```
## Compiling
Currently the MIGraphX C/C++ API requires that `offload_copy` be enabled for compiling dynamic batch programs.
Here is a snippet of compiling a model with `offload_copy` enabled:
```
migraphx::compile_options c_options;
c_options.set_offload_copy();
p.compile(migraphx::target("gpu"), c_options);
```
where `p` is the `migraphx::program`.
## Saving and Loading
A dynamic batch MIGraphX program can be saved and loaded to/from a MXR file the same way as a fully static shape program.
## Executing the dynamic batch model
The compiled dynamic batch model can be executed the same way as a static model by supplying the input data as `arguments` in a `program_parameters` object.
## Running the Example
Your ROCm installation could be installed in a location other than the one specified in the CMakeLists.txt.
You can set `LD_LIBRARY_PATH` or `CMAKE_PREFIX_PATH` to that location so that this program can still build.
The provided example is [`dynamic_batch.cpp`](./dynamic_batch.cpp)
To compile and run the example from this directory:
```
$ mkdir build
$ cd build
$ cmake ..
$ make
```
There will now be an executable named `dynamic_batch` with the following usage:
```
$ ./dynamic_batch
```
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iostream>
#include <fstream>
#include <vector>
#include <string>
#include <algorithm>
// MIGraphX C++ API
#include <migraphx/migraphx.hpp>
int main(int argc, char** argv)
{
migraphx::onnx_options o_options;
migraphx::dynamic_dimensions dyn_dims = {migraphx::dynamic_dimension{1, 4, {2, 4}},
migraphx::dynamic_dimension{3, 3},
migraphx::dynamic_dimension{4, 4},
migraphx::dynamic_dimension{5, 5}};
o_options.set_dyn_input_parameter_shape("0", dyn_dims);
auto p = migraphx::parse_onnx("../add_scalar_test.onnx", o_options);
migraphx::compile_options c_options;
c_options.set_offload_copy();
p.compile(migraphx::target("gpu"), c_options);
// batch size = 2
std::vector<uint8_t> a(2 * 3 * 4 * 5, 3);
std::vector<uint8_t> b = {2};
migraphx::program_parameters pp;
migraphx::shape s = migraphx::shape(migraphx_shape_uint8_type, {2, 3, 4, 5});
pp.add("0", migraphx::argument(s, a.data()));
pp.add("1", migraphx::argument(migraphx::shape(migraphx_shape_uint8_type, {1}, {0}), b.data()));
auto outputs = p.eval(pp);
auto result = outputs[0];
std::vector<uint8_t> c(2 * 3 * 4 * 5, 5);
if(bool{result == migraphx::argument(s, c.data())})
{
std::cout << "Successfully executed dynamic batch add\n";
}
else
{
std::cout << "Failed dynamic batch add\n";
}
return 0;
}
...@@ -53,7 +53,6 @@ See below for a comprehensive list of commands and option arguments, as well as ...@@ -53,7 +53,6 @@ See below for a comprehensive list of commands and option arguments, as well as
| --enable-offload-copy | Enable implicit offload copying | | --enable-offload-copy | Enable implicit offload copying |
| --disable-fast-math | Disable fast math optimization | | --disable-fast-math | Disable fast math optimization |
| --exhaustive-tune | Enable exhaustive search to find fastest kernel | | --exhaustive-tune | Enable exhaustive search to find fastest kernel |
| --split-single-dyn-dim | Enable split_single_dyn_dim compiler pass |
| --fp16 | Quantize for fp16 | | --fp16 | Quantize for fp16 |
| --int8 | Quantize for int8 | | --int8 | Quantize for int8 |
| --tolerance | Tolerance for errors | | --tolerance | Tolerance for errors |
......
...@@ -21,6 +21,6 @@ ...@@ -21,6 +21,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
tensorflow==2.9.3 tensorflow==2.11.1
onnxruntime onnxruntime
tokenizers tokenizers
\ No newline at end of file
...@@ -6,13 +6,12 @@ ARG PREFIX=/usr/local ...@@ -6,13 +6,12 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386 RUN dpkg --add-architecture i386
# Add rocm repository # Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.4.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list' RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.6/ focal main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
apt-utils \ apt-utils \
build-essential \ build-essential \
clang-format-10 \
cmake \ cmake \
curl \ curl \
doxygen \ doxygen \
...@@ -49,7 +48,7 @@ ENV LANG=C.UTF-8 ...@@ -49,7 +48,7 @@ ENV LANG=C.UTF-8
RUN pip3 install yapf==0.28.0 RUN pip3 install yapf==0.28.0
# Install doc requirements # Install doc requirements
ADD doc/requirements.txt /doc-requirements.txt ADD docs/.sphinx/requirements.txt /doc-requirements.txt
RUN pip3 install -r /doc-requirements.txt RUN pip3 install -r /doc-requirements.txt
# Install dependencies # Install dependencies
...@@ -59,4 +58,3 @@ ADD rbuild.ini /rbuild.ini ...@@ -59,4 +58,3 @@ ADD rbuild.ini /rbuild.ini
COPY ./tools/install_prereqs.sh / COPY ./tools/install_prereqs.sh /
RUN /install_prereqs.sh /usr/local / && rm /install_prereqs.sh RUN /install_prereqs.sh /usr/local / && rm /install_prereqs.sh
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
google/protobuf@v3.11.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off google/protobuf@v3.11.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
nlohmann/json@v3.8.0 nlohmann/json@v3.8.0
live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212 live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
ROCmSoftwarePlatform/half@rocm-5.4.2 ROCmSoftwarePlatform/half@rocm-5.6.0
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@5172ec5280f14974beee2acf1af1db3b2670244c -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
include(ExportHeader)
include(ROCMInstallTargets) include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers) include(ROCMPackageConfigHelpers)
include(RegisterOp) include(RegisterOp)
...@@ -94,6 +96,7 @@ add_library(migraphx ...@@ -94,6 +96,7 @@ add_library(migraphx
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
split_single_dyn_dim.cpp split_single_dyn_dim.cpp
target.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
...@@ -126,10 +129,11 @@ register_migraphx_ops( ...@@ -126,10 +129,11 @@ register_migraphx_ops(
contiguous contiguous
convert convert
convolution convolution
convolution_backwards
cosh cosh
cos cos
deconvolution
dequantizelinear dequantizelinear
dimensions_of
div div
dot dot
elu elu
...@@ -195,6 +199,7 @@ register_migraphx_ops( ...@@ -195,6 +199,7 @@ register_migraphx_ops(
roialign roialign
round round
rsqrt rsqrt
run_on_target
scalar scalar
scatter_add scatter_add
scatter_mul scatter_mul
...@@ -255,6 +260,7 @@ endif() ...@@ -255,6 +260,7 @@ endif()
find_package(nlohmann_json 3.8.0 REQUIRED) find_package(nlohmann_json 3.8.0 REQUIRED)
target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json) target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
migraphx_generate_export_header(migraphx)
find_package(PkgConfig) find_package(PkgConfig)
pkg_check_modules(SQLITE3 REQUIRED IMPORTED_TARGET sqlite3) pkg_check_modules(SQLITE3 REQUIRED IMPORTED_TARGET sqlite3)
......
...@@ -26,6 +26,7 @@ add_library(migraphx_c ...@@ -26,6 +26,7 @@ add_library(migraphx_c
api.cpp api.cpp
) )
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c) set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
migraphx_generate_export_header(migraphx_c DIRECTORY migraphx/api)
# migraphx_c is stable API interface library. SO version of this should be # migraphx_c is stable API interface library. SO version of this should be
# bumped when binary compatibility is broken. # bumped when binary compatibility is broken.
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <migraphx/execution_environment.hpp> #include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
...@@ -43,7 +44,7 @@ namespace migraphx { ...@@ -43,7 +44,7 @@ namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b) extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b)
{ {
disable_exception_catch = b; disable_exception_catch = b;
} }
...@@ -145,6 +146,11 @@ void set_default_dim_value(onnx_options& options, size_t value) ...@@ -145,6 +146,11 @@ void set_default_dim_value(onnx_options& options, size_t value)
options.default_dim_value = value; options.default_dim_value = value;
} }
void set_default_dyn_dim_value(onnx_options& options, const shape::dynamic_dimension& dd)
{
options.default_dyn_dim_value = dd;
}
void set_default_loop_iterations(onnx_options& options, int64_t value) void set_default_loop_iterations(onnx_options& options, int64_t value)
{ {
options.max_loop_iterations = value; options.max_loop_iterations = value;
...@@ -161,6 +167,13 @@ void set_input_parameter_shape(onnx_options& options, ...@@ -161,6 +167,13 @@ void set_input_parameter_shape(onnx_options& options,
options.map_input_dims[std::string(name)] = std::move(dims); options.map_input_dims[std::string(name)] = std::move(dims);
} }
void set_dyn_input_parameter_shape(onnx_options& options,
const char* name,
std::vector<shape::dynamic_dimension> dyn_dims)
{
options.map_dyn_input_dims[std::string(name)] = std::move(dyn_dims);
}
void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims) void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims)
{ {
options.map_input_dims[std::string(name)] = std::move(dims); options.map_input_dims[std::string(name)] = std::move(dims);
...@@ -187,6 +200,12 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& ...@@ -187,6 +200,12 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>&
return result; return result;
} }
template <class T>
std::set<T> make_set(const T* x, std::size_t n)
{
return {x, x + n};
}
void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names) void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
{ {
if(names.empty()) if(names.empty())
...@@ -346,7 +365,10 @@ const Target* object_cast(const U* x) ...@@ -346,7 +365,10 @@ const Target* object_cast(const U* x)
template <class T, class... Ts, class Target = std::remove_pointer_t<T>> template <class T, class... Ts, class Target = std::remove_pointer_t<T>>
Target* allocate(Ts&&... xs) Target* allocate(Ts&&... xs)
{ {
return new Target(std::forward<Ts>(xs)...); // NOLINT if constexpr(std::is_aggregate<Target>{})
return new Target{std::forward<Ts>(xs)...}; // NOLINT
else
return new Target(std::forward<Ts>(xs)...); // NOLINT
} }
template <class T> template <class T>
...@@ -409,6 +431,39 @@ struct manage_generic_ptr ...@@ -409,6 +431,39 @@ struct manage_generic_ptr
D deleter = nullptr; D deleter = nullptr;
}; };
extern "C" struct migraphx_optimals;
struct migraphx_optimals
{
template <class... Ts>
migraphx_optimals(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::set<size_t> object;
};
extern "C" struct migraphx_dynamic_dimension;
struct migraphx_dynamic_dimension
{
template <class... Ts>
migraphx_dynamic_dimension(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::shape::dynamic_dimension object;
};
extern "C" struct migraphx_dynamic_dimensions;
struct migraphx_dynamic_dimensions
{
template <class... Ts>
migraphx_dynamic_dimensions(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::shape::dynamic_dimension> object;
};
extern "C" struct migraphx_shape; extern "C" struct migraphx_shape;
struct migraphx_shape struct migraphx_shape
{ {
...@@ -736,6 +791,152 @@ struct migraphx_experimental_custom_op ...@@ -736,6 +791,152 @@ struct migraphx_experimental_custom_op
} }
}; };
extern "C" migraphx_status migraphx_optimals_destroy(migraphx_optimals_t optimals)
{
auto api_error_result = migraphx::try_([&] { destroy((optimals)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_optimals_assign_to(migraphx_optimals_t output,
const_migraphx_optimals_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_optimals_create(migraphx_optimals_t* optimals, const size_t* ptr, size_t size)
{
auto api_error_result = migraphx::try_([&] {
*optimals = object_cast<migraphx_optimals_t>(
allocate<std::set<size_t>>(migraphx::make_set<size_t>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_destroy(migraphx_dynamic_dimension_t dynamic_dimension)
{
auto api_error_result = migraphx::try_([&] { destroy((dynamic_dimension)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_assign_to(migraphx_dynamic_dimension_t output,
const_migraphx_dynamic_dimension_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_dynamic_dimension_create_min_max(
migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max)
{
auto api_error_result = migraphx::try_([&] {
*dynamic_dimension = object_cast<migraphx_dynamic_dimension_t>(
allocate<migraphx::shape::dynamic_dimension>((min), (max)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_create_min_max_optimals(migraphx_dynamic_dimension_t* dynamic_dimension,
size_t min,
size_t max,
migraphx_optimals_t optimals)
{
auto api_error_result = migraphx::try_([&] {
if(optimals == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter optimals: Null pointer");
*dynamic_dimension = object_cast<migraphx_dynamic_dimension_t>(
allocate<migraphx::shape::dynamic_dimension>((min), (max), (optimals->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_is_fixed(bool* out, const_migraphx_dynamic_dimension_t dynamic_dimension)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimension == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimension: Null pointer");
*out = (dynamic_dimension->object).is_fixed();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_equal(bool* out,
const_migraphx_dynamic_dimension_t dynamic_dimension,
const_migraphx_dynamic_dimension_t x)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimension == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimension: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((dynamic_dimension->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_destroy(migraphx_dynamic_dimensions_t dynamic_dimensions)
{
auto api_error_result = migraphx::try_([&] { destroy((dynamic_dimensions)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output,
const_migraphx_dynamic_dimensions_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions,
const_migraphx_dynamic_dimension_t* ptr,
size_t size)
{
auto api_error_result = migraphx::try_([&] {
*dynamic_dimensions = object_cast<migraphx_dynamic_dimensions_t>(
allocate<std::vector<migraphx::shape::dynamic_dimension>>(
migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_size(size_t* out, migraphx_dynamic_dimensions_t dynamic_dimensions)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimensions == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimensions: Null pointer");
*out = (dynamic_dimensions->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_get(const_migraphx_dynamic_dimension_t* out,
migraphx_dynamic_dimensions_t dynamic_dimensions,
size_t idx)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimensions == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimensions: Null pointer");
*out = object_cast<const_migraphx_dynamic_dimension_t>(
&((dynamic_dimensions->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{ {
auto api_error_result = migraphx::try_([&] { destroy((shape)); }); auto api_error_result = migraphx::try_([&] { destroy((shape)); });
...@@ -794,6 +995,19 @@ extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, ...@@ -794,6 +995,19 @@ extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_create_dynamic(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
migraphx_dynamic_dimensions_t dims)
{
auto api_error_result = migraphx::try_([&] {
if(dims == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)), (dims->object)));
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape) migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{ {
...@@ -824,6 +1038,17 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap ...@@ -824,6 +1038,17 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_dyn_dims(migraphx_dynamic_dimensions_t* out,
const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = allocate<migraphx_dynamic_dimensions_t>((shape->object).dyn_dims());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape) const_migraphx_shape_t shape)
{ {
...@@ -857,6 +1082,16 @@ extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shap ...@@ -857,6 +1082,16 @@ extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shap
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_ndim(size_t* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).ndim();
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x) migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x)
{ {
...@@ -880,6 +1115,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha ...@@ -880,6 +1115,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_dynamic(bool* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).dynamic();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i) extern "C" migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -915,6 +1160,17 @@ migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t s ...@@ -915,6 +1160,17 @@ migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t s
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_argument_create_empty(migraphx_argument_t* argument,
const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*argument = object_cast<migraphx_argument_t>(allocate<migraphx::argument>((shape->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument) const_migraphx_argument_t argument)
{ {
...@@ -1590,6 +1846,19 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape( ...@@ -1590,6 +1846,19 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_onnx_options_set_dyn_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, migraphx_dynamic_dimensions_t dims)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
migraphx::set_dyn_input_parameter_shape((onnx_options->object), (name), (dims->object));
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value) migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{ {
...@@ -1601,6 +1870,20 @@ migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options ...@@ -1601,6 +1870,20 @@ migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_onnx_options_set_default_dyn_dim_value(migraphx_onnx_options_t onnx_options,
const_migraphx_dynamic_dimension_t dd)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dd == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dd: Null pointer");
migraphx::set_default_dyn_dim_value((onnx_options->object), (dd->object));
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value) int64_t value)
......
...@@ -26,6 +26,9 @@ ...@@ -26,6 +26,9 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h> #include <stdbool.h>
#include <migraphx/api/export.h>
// Add new types here // Add new types here
// clang-format off // clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \ #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
...@@ -66,6 +69,15 @@ typedef enum ...@@ -66,6 +69,15 @@ typedef enum
} migraphx_shape_datatype_t; } migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
typedef struct migraphx_optimals* migraphx_optimals_t;
typedef const struct migraphx_optimals* const_migraphx_optimals_t;
typedef struct migraphx_dynamic_dimension* migraphx_dynamic_dimension_t;
typedef const struct migraphx_dynamic_dimension* const_migraphx_dynamic_dimension_t;
typedef struct migraphx_dynamic_dimensions* migraphx_dynamic_dimensions_t;
typedef const struct migraphx_dynamic_dimensions* const_migraphx_dynamic_dimensions_t;
typedef struct migraphx_shape* migraphx_shape_t; typedef struct migraphx_shape* migraphx_shape_t;
typedef const struct migraphx_shape* const_migraphx_shape_t; typedef const struct migraphx_shape* const_migraphx_shape_t;
...@@ -157,360 +169,460 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void ...@@ -157,360 +169,460 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input); typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_destroy(migraphx_optimals_t optimals);
MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_assign_to(migraphx_optimals_t output,
const_migraphx_optimals_t input);
MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_create(migraphx_optimals_t* optimals,
const size_t* ptr,
size_t size);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimension_destroy(migraphx_dynamic_dimension_t dynamic_dimension);
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_assign_to(
migraphx_dynamic_dimension_t output, const_migraphx_dynamic_dimension_t input);
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_create_min_max(
migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimension_create_min_max_optimals(migraphx_dynamic_dimension_t* dynamic_dimension,
size_t min,
size_t max,
migraphx_optimals_t optimals);
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_is_fixed(
bool* out, const_migraphx_dynamic_dimension_t dynamic_dimension);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimension_equal(bool* out,
const_migraphx_dynamic_dimension_t dynamic_dimension,
const_migraphx_dynamic_dimension_t x);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_destroy(migraphx_dynamic_dimensions_t dynamic_dimensions);
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimensions_assign_to(
migraphx_dynamic_dimensions_t output, const_migraphx_dynamic_dimensions_t input);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions,
const_migraphx_dynamic_dimension_t* ptr,
size_t size);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_size(size_t* out, migraphx_dynamic_dimensions_t dynamic_dimensions);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input); MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_get(const_migraphx_dynamic_dimension_t* out,
migraphx_dynamic_dimensions_t dynamic_dimensions,
size_t idx);
migraphx_status migraphx_shape_create(migraphx_shape_t* shape, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size);
migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_assign_to(migraphx_shape_t output,
migraphx_shape_datatype_t type, const_migraphx_shape_t input);
size_t* lengths,
size_t lengths_size,
size_t* strides,
size_t strides_size);
migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type); migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape); migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size,
size_t* strides,
size_t strides_size);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape); migraphx_shape_datatype_t type);
migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_dynamic(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
migraphx_dynamic_dimensions_t dims);
migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_lengths(const size_t** out,
size_t* out_size,
const_migraphx_shape_t shape);
migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_strides(const size_t** out,
size_t* out_size,
const_migraphx_shape_t shape);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_dyn_dims(migraphx_dynamic_dimensions_t* out,
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x); const_migraphx_shape_t shape);
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape);
migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_elements(size_t* out,
const_migraphx_shape_t shape);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_ndim(size_t* out, const_migraphx_shape_t shape);
const_migraphx_argument_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_equal(bool* out,
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer); const_migraphx_shape_t shape,
const_migraphx_shape_t x);
migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
const_migraphx_argument_t argument);
migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_dynamic(bool* out, const_migraphx_shape_t shape);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_index(size_t* out,
migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x); const_migraphx_shape_t shape,
size_t i);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed);
migraphx_status migraphx_target_destroy(migraphx_target_t target); MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input);
migraphx_status migraphx_target_assign_to(migraphx_target_t output, const_migraphx_target_t input); MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_create(migraphx_argument_t* argument,
const_migraphx_shape_t shape,
void* buffer);
migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name); MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_create_empty(migraphx_argument_t* argument,
const_migraphx_shape_t shape);
migraphx_status migraphx_program_parameter_shapes_destroy( MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument);
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_buffer(char** out,
const_migraphx_argument_t argument);
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_equal(bool* out,
const_migraphx_argument_t argument,
const_migraphx_argument_t x);
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_generate(migraphx_argument_t* out,
const_migraphx_shape_t s,
size_t seed);
MIGRAPHX_C_EXPORT migraphx_status migraphx_target_destroy(migraphx_target_t target);
MIGRAPHX_C_EXPORT migraphx_status migraphx_target_assign_to(migraphx_target_t output,
const_migraphx_target_t input);
MIGRAPHX_C_EXPORT migraphx_status migraphx_target_create(migraphx_target_t* target,
const char* name);
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes); migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_assign_to(
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output, migraphx_program_parameter_shapes_t output, const_migraphx_program_parameter_shapes_t input);
const_migraphx_program_parameter_shapes_t input);
migraphx_status migraphx_program_parameter_shapes_size( MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_size(
size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes); size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out, migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes, migraphx_program_parameter_shapes_t program_parameter_shapes,
const char* name); const char* name);
migraphx_status migraphx_program_parameter_shapes_names( MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_names(
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes); const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters); migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters);
migraphx_status migraphx_program_parameters_assign_to(migraphx_program_parameters_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameters_assign_to(
const_migraphx_program_parameters_t input); migraphx_program_parameters_t output, const_migraphx_program_parameters_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters); migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters);
migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters, MIGRAPHX_C_EXPORT migraphx_status
const char* name, migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters,
const_migraphx_argument_t argument); const char* name,
const_migraphx_argument_t argument);
migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments); MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments);
migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input); const_migraphx_arguments_t input);
migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments); MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_size(size_t* out,
migraphx_arguments_t arguments);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_get(const_migraphx_argument_t* out,
migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx); migraphx_arguments_t arguments,
size_t idx);
migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes); MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes);
migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output, const_migraphx_shapes_t input); MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output,
const_migraphx_shapes_t input);
migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes); MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_get(const_migraphx_shape_t* out,
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx); migraphx_shapes_t shapes,
size_t idx);
migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction); MIGRAPHX_C_EXPORT migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction);
migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output, MIGRAPHX_C_EXPORT migraphx_status
const_migraphx_instruction_t input); migraphx_instruction_assign_to(migraphx_instruction_t output, const_migraphx_instruction_t input);
migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions); MIGRAPHX_C_EXPORT migraphx_status
migraphx_instructions_destroy(migraphx_instructions_t instructions);
migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_instructions_assign_to(
const_migraphx_instructions_t input); migraphx_instructions_t output, const_migraphx_instructions_t input);
migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions, MIGRAPHX_C_EXPORT migraphx_status migraphx_instructions_create(
const_migraphx_instruction_t* ptr, migraphx_instructions_t* instructions, const_migraphx_instruction_t* ptr, size_t size);
size_t size);
migraphx_status migraphx_modules_destroy(migraphx_modules_t modules); MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_destroy(migraphx_modules_t modules);
migraphx_status migraphx_modules_assign_to(migraphx_modules_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input); const_migraphx_modules_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_create(migraphx_modules_t* modules,
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size); migraphx_module_t* ptr,
size_t size);
migraphx_status migraphx_module_create(migraphx_module_t* module, char* name); MIGRAPHX_C_EXPORT migraphx_status migraphx_module_create(migraphx_module_t* module, char* name);
migraphx_status migraphx_module_print(const_migraphx_module_t module); MIGRAPHX_C_EXPORT migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
migraphx_operation_t op, migraphx_operation_t op,
migraphx_instructions_t args); migraphx_instructions_t args);
migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status
migraphx_module_t module, migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_operation_t op, migraphx_module_t module,
migraphx_instructions_t args, migraphx_operation_t op,
migraphx_modules_t module_refs); migraphx_instructions_t args,
migraphx_modules_t module_refs);
migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
const_migraphx_shape_t shape, const_migraphx_shape_t shape,
const char* buffer); const char* buffer);
migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
const char* name, const char* name,
const_migraphx_shape_t shape); const_migraphx_shape_t shape);
migraphx_status migraphx_module_add_return(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
migraphx_instructions_t args); migraphx_instructions_t args);
migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
const_migraphx_shape_t s); const_migraphx_shape_t s);
migraphx_status migraphx_program_destroy(migraphx_program_t program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input); const_migraphx_program_t input);
migraphx_status migraphx_program_create(migraphx_program_t* program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_create(migraphx_program_t* program);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program); migraphx_program_t program);
migraphx_status migraphx_program_create_module(migraphx_module_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_create_module(migraphx_module_t* out,
migraphx_program_t program, migraphx_program_t program,
const char* name); const char* name);
migraphx_status migraphx_program_compile(migraphx_program_t program, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target, migraphx_target_t target,
migraphx_compile_options_t options); migraphx_compile_options_t options);
migraphx_status migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_parameter_shapes(
migraphx_program_t program); migraphx_program_parameter_shapes_t* out, migraphx_program_t program);
migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_program_t program); migraphx_program_t program);
migraphx_status migraphx_program_print(const_migraphx_program_t program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_print(const_migraphx_program_t program);
migraphx_status migraphx_program_sort(migraphx_program_t program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_sort(migraphx_program_t program);
migraphx_status migraphx_program_run(migraphx_arguments_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params); migraphx_program_parameters_t params);
migraphx_status migraphx_program_run_async(migraphx_arguments_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_run_async(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params, migraphx_program_parameters_t params,
void* s, void* s,
const char* name); const char* name);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_program_equal(bool* out,
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x); const_migraphx_program_t program,
const_migraphx_program_t x);
migraphx_status migraphx_program_experimental_get_context(migraphx_context_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_experimental_get_context(
const_migraphx_program_t program); migraphx_context_t* out, const_migraphx_program_t program);
migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_assign_to(migraphx_operation_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input); const_migraphx_operation_t input);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation, MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes, const char* attributes,
...); ...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation); MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_name(char* out,
size_t out_size,
migraphx_operation_t operation);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_load(migraphx_program_t* out,
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options); const char* name,
migraphx_file_options_t options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_save(migraphx_program_t p,
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options); const char* name,
migraphx_file_options_t options);
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_assign_to(
const_migraphx_onnx_options_t input); migraphx_onnx_options_t output, const_migraphx_onnx_options_t input);
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape( MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size); migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size);
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_dyn_input_parameter_shape(
size_t value); migraphx_onnx_options_t onnx_options, const char* name, migraphx_dynamic_dimensions_t dims);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value);
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_dyn_dim_value(
migraphx_onnx_options_t onnx_options, const_migraphx_dynamic_dimension_t dd);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_loop_iterations(
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_t onnx_options, int64_t value);
int64_t value);
migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_file_options_destroy(migraphx_file_options_t file_options);
migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_file_options_assign_to(
const_migraphx_file_options_t input); migraphx_file_options_t output, const_migraphx_file_options_t input);
migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_file_options_create(migraphx_file_options_t* file_options);
migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options, MIGRAPHX_C_EXPORT migraphx_status
const char* format); migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format);
migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options);
migraphx_status migraphx_compile_options_assign_to(migraphx_compile_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_compile_options_assign_to(
const_migraphx_compile_options_t input); migraphx_compile_options_t output, const_migraphx_compile_options_t input);
migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value); migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value);
migraphx_status migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, MIGRAPHX_C_EXPORT migraphx_status
bool value); migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_compile_options_set_exhaustive_tune_flag(
migraphx_compile_options_set_exhaustive_tune_flag(migraphx_compile_options_t compile_options, migraphx_compile_options_t compile_options, bool value);
bool value);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_onnx(migraphx_program_t* out,
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options); const char* name,
migraphx_onnx_options_t options);
migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
const void* data, const void* data,
size_t size, size_t size,
migraphx_onnx_options_t options); migraphx_onnx_options_t options);
migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options); MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options);
migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input); const_migraphx_tf_options_t input);
migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options); MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options);
migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc); MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
bool is_nhwc);
migraphx_status migraphx_tf_options_set_input_parameter_shape(migraphx_tf_options_t tf_options, MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_input_parameter_shape(
const char* name, migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size);
size_t* dims,
size_t dims_size);
migraphx_status migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, MIGRAPHX_C_EXPORT migraphx_status
size_t value); migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value);
migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options, MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_output_names(
const char** names, migraphx_tf_options_t tf_options, const char** names, size_t names_size);
size_t names_size);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_tf(migraphx_program_t* out,
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options); const char* name,
migraphx_tf_options_t options);
migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names); MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
migraphx_status migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_op_names_assign_to(
const_migraphx_quantize_op_names_t input); migraphx_quantize_op_names_t output, const_migraphx_quantize_op_names_t input);
migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names); MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, MIGRAPHX_C_EXPORT migraphx_status
const char* name); migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name);
migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_op_names_t name); migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, migraphx_quantize_op_names_t name);
migraphx_status migraphx_quantize_fp16(migraphx_program_t prog); MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options); migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_assign_to(
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output, migraphx_quantize_int8_options_t output, const_migraphx_quantize_int8_options_t input);
const_migraphx_quantize_int8_options_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options); migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_add_op_name(
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options, migraphx_quantize_int8_options_t quantize_int8_options, const char* name);
const char* name);
migraphx_status migraphx_quantize_int8_options_add_calibration_data( MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data); migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data);
migraphx_status migraphx_quantize_int8(migraphx_program_t prog, MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target, migraphx_target_t target,
migraphx_quantize_int8_options_t options); migraphx_quantize_int8_options_t options);
migraphx_status migraphx_context_finish(const_migraphx_context_t context); MIGRAPHX_C_EXPORT migraphx_status migraphx_context_finish(const_migraphx_context_t context);
migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context); MIGRAPHX_C_EXPORT migraphx_status migraphx_context_get_queue(void** out,
migraphx_context_t context);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op); migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_assign_to(
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output, migraphx_experimental_custom_op_t output, const_migraphx_experimental_custom_op_t input);
const_migraphx_experimental_custom_op_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op, migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
void* obj, void* obj,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
...@@ -518,21 +630,20 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi ...@@ -518,21 +630,20 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
const char* obj_typename, const char* obj_typename,
const char* name); const char* name);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_compute(
migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute input);
migraphx_experimental_custom_op_compute input);
migraphx_status migraphx_experimental_custom_op_set_compute_shape( MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input); migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
migraphx_status migraphx_experimental_custom_op_set_output_alias( MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_output_alias(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input); migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input);
migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target( MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_runs_on_offload_target input); migraphx_experimental_custom_op_runs_on_offload_target input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op); migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -571,10 +571,90 @@ using require_interface = ...@@ -571,10 +571,90 @@ using require_interface =
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_CONST_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, const) #define MIGRAPHX_CONST_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, const)
/**
* Container to hold optimal dynamic dimension values.
*/
struct optimals : MIGRAPHX_HANDLE_BASE(optimals)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(optimals)
optimals(std::initializer_list<size_t> init_list)
{
this->make_handle(&migraphx_optimals_create, init_list.begin(), init_list.size());
}
};
/**
* @brief Dynamic dimension object.
* @details minimum, maximum, and optimal dimensions
*/
struct dynamic_dimension : MIGRAPHX_CONST_HANDLE_BASE(dynamic_dimension)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(dynamic_dimension)
dynamic_dimension(size_t min, size_t max)
{
this->make_handle(&migraphx_dynamic_dimension_create_min_max, min, max);
}
dynamic_dimension(size_t min, size_t max, const optimals& opts)
{
this->make_handle(
&migraphx_dynamic_dimension_create_min_max_optimals, min, max, opts.get_handle_ptr());
}
bool is_fixed() const
{
bool result = false;
call(&migraphx_dynamic_dimension_is_fixed, &result, this->get_handle_ptr());
return result;
}
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y)
{
bool pout;
call(&migraphx_dynamic_dimension_equal, &pout, x.get_handle_ptr(), y.get_handle_ptr());
return pout;
}
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y)
{
return not(x == y);
}
};
/**
* Container to hold dynamic_dimension objects.
*/
struct dynamic_dimensions : MIGRAPHX_HANDLE_BASE(dynamic_dimensions)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(dynamic_dimensions)
template <class... Ts>
dynamic_dimensions(Ts... xs)
{
std::array<const_migraphx_dynamic_dimension_t, sizeof...(Ts)> a{xs.get_handle_ptr()...};
this->make_handle(&migraphx_dynamic_dimensions_create, a.data(), a.size());
}
size_t size() const
{
size_t pout;
call(&migraphx_dynamic_dimensions_size, &pout, this->get_handle_ptr());
return pout;
}
dynamic_dimension operator[](size_t pidx) const
{
const_migraphx_dynamic_dimension_t pout;
call(&migraphx_dynamic_dimensions_get, &pout, this->get_handle_ptr(), pidx);
return {pout, this->share_handle()};
}
};
/** /**
* @brief Describe shape of tensor * @brief Describe shape of tensor
* @details A shape consists of a data type, lengths of multi-dimension tensor, and strides * @details A shape consists of a data type, lengths of multi-dimension tensor, and strides
*
*/ */
struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{ {
...@@ -598,6 +678,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -598,6 +678,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size()); this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size());
} }
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape(migraphx_shape_datatype_t t, std::initializer_list<std::size_t> d)
: shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
{
}
shape(migraphx_shape_datatype_t type, shape(migraphx_shape_datatype_t type,
std::vector<size_t> plengths, std::vector<size_t> plengths,
std::vector<size_t> pstrides) std::vector<size_t> pstrides)
...@@ -610,6 +697,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -610,6 +697,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
pstrides.size()); pstrides.size());
} }
shape(migraphx_shape_datatype_t type, const dynamic_dimensions& dyn_dims)
{
this->make_handle(&migraphx_shape_create_dynamic, type, dyn_dims.get_handle_ptr());
}
std::vector<size_t> lengths() const std::vector<size_t> lengths() const
{ {
const size_t* pout; const size_t* pout;
...@@ -626,6 +718,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -626,6 +718,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return {pout, pout + pout_size}; return {pout, pout + pout_size};
} }
/// Get the dynamic dimensions of the shape
dynamic_dimensions dyn_dims() const
{
migraphx_dynamic_dimensions_t pout;
call(&migraphx_shape_dyn_dims, &pout, this->get_handle_ptr());
return {pout, own{}};
}
migraphx_shape_datatype_t type() const migraphx_shape_datatype_t type() const
{ {
migraphx_shape_datatype_t pout; migraphx_shape_datatype_t pout;
...@@ -654,6 +754,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -654,6 +754,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return result; return result;
} }
/// Is the shape dynamic
bool dynamic() const
{
bool result = false;
call(&migraphx_shape_dynamic, &result, this->get_handle_ptr());
return result;
}
// map element index to space index // map element index to space index
size_t index(size_t i) const size_t index(size_t i) const
{ {
...@@ -687,6 +795,11 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -687,6 +795,11 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
argument(shape pshape)
{
this->make_handle(&migraphx_argument_create_empty, pshape.get_handle_ptr());
}
argument(shape pshape, void* pbuffer) argument(shape pshape, void* pbuffer)
{ {
this->make_handle(&migraphx_argument_create, pshape.get_handle_ptr(), pbuffer); this->make_handle(&migraphx_argument_create, pshape.get_handle_ptr(), pbuffer);
...@@ -1182,12 +1295,27 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -1182,12 +1295,27 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
dim.size()); dim.size());
} }
void set_dyn_input_parameter_shape(const std::string& name, const dynamic_dimensions& dyn_dims)
{
call(&migraphx_onnx_options_set_dyn_input_parameter_shape,
this->get_handle_ptr(),
name.c_str(),
dyn_dims.get_handle_ptr());
}
/// When there is a dimension parameter, then use this default value /// When there is a dimension parameter, then use this default value
void set_default_dim_value(unsigned int value) void set_default_dim_value(unsigned int value)
{ {
call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value); call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value);
} }
void set_default_dyn_dim_value(const dynamic_dimension& dd)
{
call(&migraphx_onnx_options_set_default_dyn_dim_value,
this->get_handle_ptr(),
dd.get_handle_ptr());
}
/// Set default max iteration number for the loop operator /// Set default max iteration number for the loop operator
void set_default_loop_iterations(int64_t value) void set_default_loop_iterations(int64_t value)
{ {
...@@ -1359,13 +1487,17 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op ...@@ -1359,13 +1487,17 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
struct experimental_custom_op_base struct experimental_custom_op_base
{ {
experimental_custom_op_base() = default;
experimental_custom_op_base(const experimental_custom_op_base&) = default;
experimental_custom_op_base& operator=(const experimental_custom_op_base&) = default;
virtual ~experimental_custom_op_base() = default;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual argument compute(context ctx, shape output, arguments inputs) const = 0; virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
virtual shape compute_shape(shapes inputs) const = 0; virtual shape compute_shape(shapes inputs) const = 0;
virtual std::vector<size_t> output_alias(shapes) const { return {}; } virtual std::vector<size_t> output_alias(shapes) const { return {}; }
// TODO: Return target string instead of bool // TODO: Return target string instead of bool
virtual bool runs_on_offload_target() const = 0; virtual bool runs_on_offload_target() const = 0;
virtual ~experimental_custom_op_base() = default;
}; };
struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)> struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
......
...@@ -45,56 +45,48 @@ def shape_type_wrap(p): ...@@ -45,56 +45,48 @@ def shape_type_wrap(p):
p.read = 'migraphx::to_shape_type(${name})' p.read = 'migraphx::to_shape_type(${name})'
@api.cwrap('migraphx::compile_options') def auto_handle(*args, **kwargs):
def compile_options_type_wrap(p): def with_handle(f):
if p.returns: return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
p.add_param('migraphx_compile_options *') *args, **kwargs)(f)
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_compile_options(${result})']
else:
p.add_param('migraphx_compile_options *')
p.read = '${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})'
@api.cwrap('migraphx::file_options')
def file_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_file_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_file_options(${result})']
else:
p.add_param('migraphx_file_options *')
p.read = '${name} == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*${name})'
return with_handle
@api.cwrap('migraphx::onnx_options')
def onnx_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_onnx_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_onnx_options(${result})']
else:
p.add_param('migraphx_onnx_options *')
p.read = '${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})'
@api.handle('migraphx_optimals', 'std::set<size_t>')
def optimals(h):
h.constructor('create',
api.params(ptr='const size_t*', size='size_t'),
fname='migraphx::make_set<size_t>')
@api.cwrap('migraphx::tf_options')
def tf_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_tf_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_tf_options(${result})']
else:
p.add_param('migraphx_tf_options *')
p.read = '${name} == nullptr ? migraphx::tf_options{} : migraphx::to_tf_options(*${name})'
@api.handle('migraphx_dynamic_dimension', 'migraphx::shape::dynamic_dimension')
def dynamic_dimension(h):
h.constructor('create_min_max', api.params(min='size_t', max='size_t'))
h.constructor(
'create_min_max_optimals',
api.params(min='size_t', max='size_t', optimals='std::set<size_t>'))
h.method('is_fixed', returns='bool', const=True)
h.method('equal',
api.params(x='const migraphx::shape::dynamic_dimension&'),
invoke='migraphx::equal($@)',
returns='bool',
const=True)
def auto_handle(*args, **kwargs):
def with_handle(f):
return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
*args, **kwargs)(f)
return with_handle @api.handle('migraphx_dynamic_dimensions',
'std::vector<migraphx::shape::dynamic_dimension>')
def dynamic_dimensions(h):
h.constructor(
'create',
api.params(ptr='const_migraphx_dynamic_dimension_t*', size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>')
h.method('size', returns='size_t')
h.method('get',
api.params(idx='size_t'),
fname='at',
cpp_name='operator[]',
returns='const migraphx::shape::dynamic_dimension&')
@auto_handle() @auto_handle()
...@@ -109,20 +101,29 @@ def shape(h): ...@@ -109,20 +101,29 @@ def shape(h):
lengths='std::vector<size_t>', lengths='std::vector<size_t>',
strides='std::vector<size_t>')) strides='std::vector<size_t>'))
h.constructor('create_scalar', api.params(type='migraphx::shape::type_t')) h.constructor('create_scalar', api.params(type='migraphx::shape::type_t'))
h.constructor(
'create_dynamic',
api.params(type='migraphx::shape::type_t',
dims='std::vector<migraphx::shape::dynamic_dimension>'))
h.method('lengths', h.method('lengths',
fname='lens', fname='lens',
returns='const std::vector<size_t>&', returns='const std::vector<size_t>&',
const=True) const=True)
h.method('strides', returns='const std::vector<size_t>&', const=True) h.method('strides', returns='const std::vector<size_t>&', const=True)
h.method('dyn_dims',
returns='std::vector<migraphx::shape::dynamic_dimension>',
const=True)
h.method('type', returns='migraphx::shape::type_t', const=True) h.method('type', returns='migraphx::shape::type_t', const=True)
h.method('elements', returns='size_t', const=True) h.method('elements', returns='size_t', const=True)
h.method('bytes', returns='size_t', const=True) h.method('bytes', returns='size_t', const=True)
h.method('ndim', returns='size_t', const=True)
h.method('equal', h.method('equal',
api.params(x='const migraphx::shape&'), api.params(x='const migraphx::shape&'),
invoke='migraphx::equal($@)', invoke='migraphx::equal($@)',
returns='bool', returns='bool',
const=True) const=True)
h.method('standard', returns='bool', const=True) h.method('standard', returns='bool', const=True)
h.method('dynamic', returns='bool', const=True)
h.method('index', api.params(i='size_t'), returns='size_t', const=True) h.method('index', api.params(i='size_t'), returns='size_t', const=True)
...@@ -130,6 +131,7 @@ def shape(h): ...@@ -130,6 +131,7 @@ def shape(h):
def argument(h): def argument(h):
h.constructor('create', h.constructor('create',
api.params(shape='const migraphx::shape&', buffer='void*')) api.params(shape='const migraphx::shape&', buffer='void*'))
h.constructor('create_empty', api.params(shape='const migraphx::shape&'))
h.method('shape', h.method('shape',
fname='get_shape', fname='get_shape',
cpp_name='get_shape', cpp_name='get_shape',
...@@ -325,11 +327,22 @@ def onnx_options(h): ...@@ -325,11 +327,22 @@ def onnx_options(h):
api.params(name='const char*', dims='std::vector<size_t>'), api.params(name='const char*', dims='std::vector<size_t>'),
invoke='migraphx::set_input_parameter_shape($@)', invoke='migraphx::set_input_parameter_shape($@)',
) )
h.method(
'set_dyn_input_parameter_shape',
api.params(name='const char*',
dims='std::vector<migraphx::shape::dynamic_dimension>'),
invoke='migraphx::set_dyn_input_parameter_shape($@)',
)
h.method( h.method(
'set_default_dim_value', 'set_default_dim_value',
api.params(value='size_t'), api.params(value='size_t'),
invoke='migraphx::set_default_dim_value($@)', invoke='migraphx::set_default_dim_value($@)',
) )
h.method(
'set_default_dyn_dim_value',
api.params(dd='const migraphx::shape::dynamic_dimension&'),
invoke='migraphx::set_default_dyn_dim_value($@)',
)
h.method( h.method(
'set_default_loop_iterations', 'set_default_loop_iterations',
api.params(value='int64_t'), api.params(value='int64_t'),
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -31,20 +31,6 @@ ...@@ -31,20 +31,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1) std::vector<std::size_t> s1)
{ {
...@@ -77,32 +63,38 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -77,32 +63,38 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
} }
auto offset = s1.ndim() - s0.ndim(); auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims()); std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
std::transform( std::transform(s0.dyn_dims().cbegin(),
s0.dyn_dims().cbegin(), s0.dyn_dims().cend(),
s0.dyn_dims().cend(), s1.dyn_dims().cbegin() + offset,
s1.dyn_dims().cbegin() + offset, out_dims.begin() + offset,
out_dims.begin() + offset, [&](auto a, auto b) {
[&](auto a, auto b) { if(a == b or b == 1)
if(a == b) {
{ return a;
return a; }
} else if(a == 1)
else if(a == 1 or b == 1) {
{ return b;
// setting optimals to empty, may need to be changed }
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max)}; else
} {
else MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
{ migraphx::to_string_range(s0.dyn_dims()) + "} and {" +
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" + migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!");
migraphx::to_string_range(s0.dyn_dims()) + "} and {" + }
migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!"); });
}
});
return out_dims; return out_dims;
} }
// Compute the common (broadcasted) dimensions of a list of fixed shapes std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<shape>& shapes)
{
auto ret_shape = shapes.at(0);
std::for_each(shapes.cbegin() + 1, shapes.cend(), [&](auto s) {
ret_shape = shape{ret_shape.type(), compute_broadcasted_dyn_dims(ret_shape, s)};
});
return ret_shape.dyn_dims();
}
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes) std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{ {
assert(not shapes.empty()); assert(not shapes.empty());
...@@ -154,34 +146,30 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> ...@@ -154,34 +146,30 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
if(std::any_of( if(std::any_of(
inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); })) inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
{ {
// currently only handles the binary case auto input_shapes = to_shapes(inputs);
if(inputs.size() != 2) auto c_type = compute_common_types(input_shapes);
{ auto c_dyn_dims = compute_common_dyn_dims(input_shapes);
MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
"inputs, only handle two inputs if any are dynamic shape");
}
auto c_type = compute_common_types(to_shapes(inputs));
auto c_dyn_dims =
compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape());
// following should work for a static or dynamic shape auto s0 = inputs[0]->get_shape();
if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims) if(not s0.dynamic() or s0.dyn_dims() != c_dyn_dims)
{ {
inputs[0] = m.insert_instruction( inputs[0] = m.insert_instruction(
ins, ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[0],
inputs[1]);
}
if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims)
{
inputs[1] = m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[1],
inputs[0]);
} }
std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
// uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime
auto s = input->get_shape();
if(not s.dynamic() or s.dyn_dims() != c_dyn_dims)
{
return m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
input,
inputs[0]);
}
return input;
});
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type) if(input->get_shape().type() != c_type)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment