Commit 0369e974 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'batch_report' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 3a474fca d70fd0df
...@@ -80,10 +80,10 @@ jobs: ...@@ -80,10 +80,10 @@ jobs:
uses: pat-s/always-upload-cache@v2.1.3 uses: pat-s/always-upload-cache@v2.1.3
with: with:
path: cppcheck-cache path: cppcheck-cache
key: cppcheck-cache-${{ steps.cache_timestamp.outputs.timestamp }} key: cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ steps.cache_timestamp.outputs.timestamp }}
restore-keys: | restore-keys: |
cppcheck-cache-${{ steps.cache_timestamp.outputs.timestamp }} cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ steps.cache_timestamp.outputs.timestamp }}
cppcheck-cache- cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-
- name: Build the Docker image - name: Build the Docker image
run: docker build . --file hip-clang.docker --tag migraphx run: docker build . --file hip-clang.docker --tag migraphx
......
...@@ -36,7 +36,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED) ...@@ -36,7 +36,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 1.3) rocm_setup_version(VERSION 2.0)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
...@@ -190,6 +190,7 @@ rocm_enable_cppcheck( ...@@ -190,6 +190,7 @@ rocm_enable_cppcheck(
shadowVariable shadowVariable
unsafeClassDivZero unsafeClassDivZero
definePrefix:*test/include/test.hpp definePrefix:*test/include/test.hpp
ctuOneDefinitionRuleViolation:*test/*
useSmartPointer:*src/api/api.cpp useSmartPointer:*src/api/api.cpp
useSmartPointer:*make_shared_array.hpp useSmartPointer:*make_shared_array.hpp
constParameter:*src/targets/gpu/*.cpp constParameter:*src/targets/gpu/*.cpp
......
...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local ...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386 RUN dpkg --add-architecture i386
# Add rocm repository # Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.2/ xenial main > /etc/apt/sources.list.d/rocm.list' RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.5/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
...@@ -32,6 +32,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -32,6 +32,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
software-properties-common \ software-properties-common \
wget \ wget \
rocm-device-libs \ rocm-device-libs \
hip-base \
libnuma-dev \
miopen-hip \ miopen-hip \
rocblas \ rocblas \
zlib1g-dev && \ zlib1g-dev && \
......
...@@ -20,7 +20,7 @@ def rocmtestnode(Map conf) { ...@@ -20,7 +20,7 @@ def rocmtestnode(Map conf) {
rm -rf build rm -rf build
mkdir build mkdir build
cd build cd build
CXX=${compiler} CXXFLAGS='-Werror -Wno-fallback' cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} .. CXX=${compiler} CXXFLAGS='-Werror' cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} ..
make -j\$(nproc) generate all doc package check VERBOSE=1 make -j\$(nproc) generate all doc package check VERBOSE=1
""" """
echo cmd echo cmd
...@@ -75,6 +75,8 @@ def rocmnodename(name) { ...@@ -75,6 +75,8 @@ def rocmnodename(name) {
node_name = "${rocmtest_name} && fiji"; node_name = "${rocmtest_name} && fiji";
} else if(name == "vega") { } else if(name == "vega") {
node_name = "${rocmtest_name} && vega"; node_name = "${rocmtest_name} && vega";
} else if(name == "navi21") {
node_name = "${rocmtest_name} && navi21";
} else if(name == "nogpu") { } else if(name == "nogpu") {
return rocmtest_name; return rocmtest_name;
} }
...@@ -110,6 +112,10 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build -> ...@@ -110,6 +112,10 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
def debug_flags = "-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" def debug_flags = "-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
} }
}, clang_release_navi: rocmnode('navi21') { cmake_build ->
stage('HIP Clang Release Navi') {
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=release")
}
} }
def onnxnode(name, body) { def onnxnode(name, body) {
......
...@@ -152,6 +152,24 @@ ...@@ -152,6 +152,24 @@
<summary>Else statement is not necessary.</summary> <summary>Else statement is not necessary.</summary>
</message> </message>
</rule> </rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[((?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)*) (\w+) ; \2 = static_cast < \1 > (\([^()]*(?-1)*[^()]*\)) ;]]></pattern>
<message>
<id>RedundantCast</id>
<severity>style</severity>
<summary>Static cast is redundant.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[auto (\w+) ; \1 = static_cast < (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* > (\([^()]*(?-1)*[^()]*\)) ;]]></pattern>
<message>
<id>RedundantCast</id>
<severity>style</severity>
<summary>Static cast is redundant.</summary>
</message>
</rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[\? (true|false) : (true|false)]]></pattern> <pattern><![CDATA[\? (true|false) : (true|false)]]></pattern>
......
sphinx==2.2.2 docutils==0.17.1
breathe==4.13.1 sphinx==4.2.0
breathe==4.31.0
sphinx_rtd_theme==1.0.0
# git+https://github.com/arximboldi/breathe@fix-node-parent # git+https://github.com/arximboldi/breathe@fix-node-parent
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
# #
# import os # import os
# import sys # import sys
from datetime import date
import re
# sys.path.insert(0, os.path.abspath('.')) # sys.path.insert(0, os.path.abspath('.'))
# -- General configuration ------------------------------------------------ # -- General configuration ------------------------------------------------
...@@ -29,7 +31,9 @@ ...@@ -29,7 +31,9 @@
# Add any Sphinx extension module names here, as strings. They can be # Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones. # ones.
extensions = ['breathe', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode'] extensions = [
'breathe', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode', 'sphinx_rtd_theme'
]
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ['_templates']
...@@ -45,7 +49,7 @@ master_doc = 'index' ...@@ -45,7 +49,7 @@ master_doc = 'index'
# General information about the project. # General information about the project.
project = u'MIGraphX' project = u'MIGraphX'
copyright = u'2018, AMD' copyright = u'2018-{}, AMD'.format(date.today().year)
author = u'AMD' author = u'AMD'
# The version info for the project you're documenting, acts as replacement for # The version info for the project you're documenting, acts as replacement for
...@@ -53,9 +57,12 @@ author = u'AMD' ...@@ -53,9 +57,12 @@ author = u'AMD'
# built documents. # built documents.
# #
# The short X.Y version. # The short X.Y version.
version = u'0.1' with open('../../CMakeLists.txt') as file:
version = next((re.findall('[0-9.]+', line)[0]
for line in file.readlines()
if 'rocm_setup_version' in line))
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
release = u'0.1' release = version
# The language for content autogenerated by Sphinx. Refer to documentation # The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages. # for a list of supported languages.
...@@ -82,7 +89,7 @@ todo_include_todos = False ...@@ -82,7 +89,7 @@ todo_include_todos = False
# The theme to use for HTML and HTML Help pages. See the documentation for # The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes. # a list of builtin themes.
# #
html_theme = 'alabaster' html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a theme # Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the # further. For a list of options available for each theme, see the
......
...@@ -25,8 +25,8 @@ migraphx::save(p, output_file); ...@@ -25,8 +25,8 @@ migraphx::save(p, output_file);
``` ```
migraphx::program p = ... <migraphx::program>; migraphx::program p = ... <migraphx::program>;
migraphx_file_options options; migraphx::file_options options;
options.format = "msgpack"; options.set_file_format("msgpack");
migraphx::save(p, output_file, options); migraphx::save(p, output_file, options);
``` ```
...@@ -41,15 +41,15 @@ p = migraphx::load(input_file); ...@@ -41,15 +41,15 @@ p = migraphx::load(input_file);
``` ```
migraphx::program p; migraphx::program p;
migraphx_file_options options; migraphx::file_options options;
options.format = "msgpack"; options.set_file_format("msgpack");
p = migraphx::load(input_file, options); p = migraphx::load(input_file, options);
``` ```
To load a program that has been saved in JSON format: To load a program that has been saved in JSON format:
``` ```
migraphx::program p; migraphx::program p;
migraphx_file_options options; migraphx::file_options options;
options.format = "json"; options.set_file_format("json");
p = migraphx::load(input_file, options); p = migraphx::load(input_file, options);
``` ```
......
...@@ -44,15 +44,15 @@ int main(int argc, char** argv) ...@@ -44,15 +44,15 @@ int main(int argc, char** argv)
std::string format = load_arg; std::string format = load_arg;
if(format == "json") if(format == "json")
{ {
migraphx_file_options options; migraphx::file_options options;
options.format = "json"; options.set_file_format("json");
p = migraphx::load(input_file, options); p = migraphx::load(input_file, options);
} }
else if(format == "msgpack") else if(format == "msgpack")
{ {
migraphx_file_options options; migraphx::file_options options;
options.format = "msgpack"; options.set_file_format("msgpack");
p = migraphx::load(input_file, options); p = migraphx::load(input_file, options);
} }
else else
p = migraphx::load(input_file); p = migraphx::load(input_file);
...@@ -80,8 +80,8 @@ int main(int argc, char** argv) ...@@ -80,8 +80,8 @@ int main(int argc, char** argv)
output_file = save_arg == nullptr ? "out" : save_arg; output_file = save_arg == nullptr ? "out" : save_arg;
output_file.append(".msgpack"); output_file.append(".msgpack");
migraphx_file_options options; migraphx::file_options options;
options.format = "msgpack"; options.set_file_format("msgpack");
migraphx::save(p, output_file.c_str(), options); migraphx::save(p, output_file.c_str(), options);
std::cout << "Program has been saved as ./" << output_file << std::endl; std::cout << "Program has been saved as ./" << output_file << std::endl;
} }
......
...@@ -60,14 +60,14 @@ migraphx::quantize_int8(prog, targ, quant_opts); ...@@ -60,14 +60,14 @@ migraphx::quantize_int8(prog, targ, quant_opts);
## Compilation ## Compilation
Network graphs saved in e.g. ONNX or protobuf format are not target-specific. In order to run inference, we must compile the graph into a target-specific program. Network graphs saved in e.g. ONNX or protobuf format are not target-specific. In order to run inference, we must compile the graph into a target-specific program.
Two options may be turned on (default for both is `false`) when compiling: Two options may be turned on when compiling:
- `bool offload_copy`: For targets with offloaded memory (such as the gpu), this will insert instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory. - `set_offload_copy(bool value)`: For targets with offloaded memory (such as the gpu), this will insert instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory. Default value is `false` for offload_copy.
- `bool fast_math`: Optimize math functions to use faster approximate versions. There may be slight accuracy degredation when enabled. - `set_fast_math(bool value)`: Optimize math functions to use faster approximate versions. There may be slight accuracy degredation when enabled. Default value is `true` for fast_math.
The following snippet assumes `targ` has been set as "gpu", and will compile the program without the fast_math optimization. The following snippet assumes `targ` has been set as "gpu", and will compile the program without the fast_math optimization.
``` ```
migraphx_compile_options comp_opts; migraphx::compile_options comp_opts;
comp_opts.offload_copy = true; comp_opts.set_offload_copy();
prog.compile(targ, comp_opts); prog.compile(targ, comp_opts);
``` ```
......
...@@ -99,8 +99,8 @@ int main(int argc, char** argv) ...@@ -99,8 +99,8 @@ int main(int argc, char** argv)
if(GPU) if(GPU)
{ {
migraphx_compile_options comp_opts; migraphx::compile_options comp_opts;
comp_opts.offload_copy = true; comp_opts.set_offload_copy();
prog.compile(targ, comp_opts); prog.compile(targ, comp_opts);
} }
else else
......
...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local ...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386 RUN dpkg --add-architecture i386
# Add rocm repository # Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.2/ xenial main > /etc/apt/sources.list.d/rocm.list' RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.5/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
...@@ -29,6 +29,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -29,6 +29,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
software-properties-common \ software-properties-common \
wget \ wget \
rocm-device-libs \ rocm-device-libs \
hip-base \
libnuma-dev \
miopen-hip \ miopen-hip \
rocblas \ rocblas \
zlib1g-dev && \ zlib1g-dev && \
......
...@@ -7,6 +7,7 @@ include(CheckCXXLinkerFlag) ...@@ -7,6 +7,7 @@ include(CheckCXXLinkerFlag)
add_library(migraphx add_library(migraphx
adjust_allocation.cpp adjust_allocation.cpp
analyze_streams.cpp analyze_streams.cpp
apply_alpha_beta.cpp
argument.cpp argument.cpp
auto_contiguous.cpp auto_contiguous.cpp
common.cpp common.cpp
...@@ -14,7 +15,6 @@ add_library(migraphx ...@@ -14,7 +15,6 @@ add_library(migraphx
convert_to_json.cpp convert_to_json.cpp
cpp_generator.cpp cpp_generator.cpp
dead_code_elimination.cpp dead_code_elimination.cpp
decompose.cpp
dom_info.cpp dom_info.cpp
dynamic_loader.cpp dynamic_loader.cpp
eliminate_allocation.cpp eliminate_allocation.cpp
...@@ -26,6 +26,7 @@ add_library(migraphx ...@@ -26,6 +26,7 @@ add_library(migraphx
eliminate_pad.cpp eliminate_pad.cpp
env.cpp env.cpp
file_buffer.cpp file_buffer.cpp
fuse_pointwise.cpp
generate.cpp generate.cpp
inline_module.cpp inline_module.cpp
insert_pad.cpp insert_pad.cpp
...@@ -52,7 +53,6 @@ add_library(migraphx ...@@ -52,7 +53,6 @@ add_library(migraphx
reduce_dims.cpp reduce_dims.cpp
register_op.cpp register_op.cpp
register_target.cpp register_target.cpp
remap.cpp
simplify_qdq.cpp simplify_qdq.cpp
rewrite_batchnorm.cpp rewrite_batchnorm.cpp
rewrite_pooling.cpp rewrite_pooling.cpp
...@@ -131,8 +131,11 @@ register_migraphx_ops( ...@@ -131,8 +131,11 @@ register_migraphx_ops(
multibroadcast multibroadcast
multinomial multinomial
neg neg
nonmaxsuppression
nonzero
outline outline
pad pad
pointwise
pooling pooling
pow pow
prefix_scan_sum prefix_scan_sum
...@@ -153,6 +156,7 @@ register_migraphx_ops( ...@@ -153,6 +156,7 @@ register_migraphx_ops(
rnn_last_cell_output rnn_last_cell_output
rnn_last_hs_output rnn_last_hs_output
rnn_var_sl_last_output rnn_var_sl_last_output
roialign
round round
rsqrt rsqrt
scalar scalar
...@@ -198,6 +202,9 @@ target_link_libraries(migraphx PRIVATE -ldl) ...@@ -198,6 +202,9 @@ target_link_libraries(migraphx PRIVATE -ldl)
target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
find_package(Threads)
target_link_libraries(migraphx PUBLIC Threads::Threads)
find_package(msgpack REQUIRED) find_package(msgpack REQUIRED)
target_link_libraries(migraphx PRIVATE msgpackc-cxx) target_link_libraries(migraphx PRIVATE msgpackc-cxx)
# Make this available to the tests # Make this available to the tests
...@@ -235,6 +242,7 @@ rocm_export_targets( ...@@ -235,6 +242,7 @@ rocm_export_targets(
TARGETS migraphx::migraphx migraphx_all_targets TARGETS migraphx::migraphx migraphx_all_targets
NAMESPACE migraphx:: NAMESPACE migraphx::
DEPENDS DEPENDS
Threads
${PACKAGE_DEPENDS} ${PACKAGE_DEPENDS}
) )
......
...@@ -3,7 +3,7 @@ add_library(migraphx_c ...@@ -3,7 +3,7 @@ add_library(migraphx_c
api.cpp api.cpp
) )
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c) set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 2.0) rocm_set_soversion(migraphx_c 3.0)
rocm_clang_tidy_check(migraphx_c) rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets) target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
#include <cstdarg>
namespace migraphx { namespace migraphx {
...@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o ...@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names); migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
} }
operation create_op(const char* name, const char* attributes) #ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
operation create_op(const char* name, const char* attributes, va_list vlist)
{ {
std::string sattributes = attributes == nullptr ? "" : attributes;
std::vector<char> buffer(sattributes.size() * 2);
std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist);
value v = value::object{}; value v = value::object{};
if(attributes != nullptr) if(attributes != nullptr)
{ {
v = from_json_string(convert_to_json(std::string(attributes))); v = from_json_string(convert_to_json(std::string(buffer.data())));
} }
auto op = make_op(name, v); auto op = make_op(name, v);
return op; return op;
} }
#ifdef __clang__
#pragma clang diagnostic pop
#endif
template <class T> template <class T>
bool equal(const T& x, const T& y) bool equal(const T& x, const T& y)
{ {
...@@ -368,7 +381,8 @@ struct migraphx_quantize_int8_options ...@@ -368,7 +381,8 @@ struct migraphx_quantize_int8_options
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{ {
return migraphx::try_([&] { destroy((shape)); }); auto api_error_result = migraphx::try_([&] { destroy((shape)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape, extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
...@@ -376,13 +390,14 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape, ...@@ -376,13 +390,14 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
size_t* lengths, size_t* lengths,
size_t lengths_size) size_t lengths_size)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0) if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
*shape = object_cast<migraphx_shape_t>( *shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)), allocate<migraphx::shape>((migraphx::to_shape_type(type)),
(std::vector<size_t>(lengths, lengths + lengths_size)))); (std::vector<size_t>(lengths, lengths + lengths_size))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape, extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
...@@ -392,7 +407,7 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* ...@@ -392,7 +407,7 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
size_t* strides, size_t* strides,
size_t strides_size) size_t strides_size)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0) if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
if(strides == nullptr and strides_size != 0) if(strides == nullptr and strides_size != 0)
...@@ -402,21 +417,23 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* ...@@ -402,21 +417,23 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
(std::vector<size_t>(lengths, lengths + lengths_size)), (std::vector<size_t>(lengths, lengths + lengths_size)),
(std::vector<size_t>(strides, strides + strides_size)))); (std::vector<size_t>(strides, strides + strides_size))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type) migraphx_shape_datatype_t type)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*shape = object_cast<migraphx_shape_t>( *shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)))); allocate<migraphx::shape>((migraphx::to_shape_type(type))));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape) migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr) if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr) if(shape == nullptr)
...@@ -425,12 +442,13 @@ migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shap ...@@ -425,12 +442,13 @@ migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data(); *out = api_result.data();
*out_size = api_result.size(); *out_size = api_result.size();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape) migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr) if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr) if(shape == nullptr)
...@@ -439,127 +457,141 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap ...@@ -439,127 +457,141 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data(); *out = api_result.data();
*out_size = api_result.size(); *out_size = api_result.size();
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape) const_migraphx_shape_t shape)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(out == nullptr) if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr) if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = migraphx::to_shape_type((shape->object).type()); *out = migraphx::to_shape_type((shape->object).type());
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(shape == nullptr) if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).bytes(); *out = (shape->object).bytes();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x) migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(shape == nullptr) if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
if(x == nullptr) if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((shape->object), (x->object)); *out = migraphx::equal((shape->object), (x->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument) extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{ {
return migraphx::try_([&] { destroy((argument)); }); auto api_error_result = migraphx::try_([&] { destroy((argument)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer) migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(shape == nullptr) if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*argument = object_cast<migraphx_argument_t>( *argument = object_cast<migraphx_argument_t>(
allocate<migraphx::argument>((shape->object), (buffer))); allocate<migraphx::argument>((shape->object), (buffer)));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument) const_migraphx_argument_t argument)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(argument == nullptr) if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((argument->object).get_shape())); *out = object_cast<const_migraphx_shape_t>(&((argument->object).get_shape()));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument) extern "C" migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(argument == nullptr) if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = (argument->object).data(); *out = (argument->object).data();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x) migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(argument == nullptr) if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
if(x == nullptr) if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((argument->object), (x->object)); *out = migraphx::equal((argument->object), (x->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed) migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(s == nullptr) if(s == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
*out = allocate<migraphx_argument_t>(migraphx::generate_argument((s->object), (seed))); *out = allocate<migraphx_argument_t>(migraphx::generate_argument((s->object), (seed)));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target) extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target)
{ {
return migraphx::try_([&] { destroy((target)); }); auto api_error_result = migraphx::try_([&] { destroy((target)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name) extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*target = object_cast<migraphx_target_t>( *target = object_cast<migraphx_target_t>(
allocate<migraphx::target>(migraphx::get_target((name)))); allocate<migraphx::target>(migraphx::get_target((name))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_parameter_shapes_destroy( extern "C" migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes) migraphx_program_parameter_shapes_t program_parameter_shapes)
{ {
return migraphx::try_([&] { destroy((program_parameter_shapes)); }); auto api_error_result = migraphx::try_([&] { destroy((program_parameter_shapes)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameter_shapes_size(size_t* out, migraphx_program_parameter_shapes_size(size_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes) migraphx_program_parameter_shapes_t program_parameter_shapes)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr) if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer"); "Bad parameter program_parameter_shapes: Null pointer");
*out = (program_parameter_shapes->object).size(); *out = (program_parameter_shapes->object).size();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
...@@ -567,19 +599,20 @@ migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out, ...@@ -567,19 +599,20 @@ migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes, migraphx_program_parameter_shapes_t program_parameter_shapes,
const char* name) const char* name)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr) if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer"); "Bad parameter program_parameter_shapes: Null pointer");
*out = *out =
object_cast<const_migraphx_shape_t>(&((program_parameter_shapes->object).at((name)))); object_cast<const_migraphx_shape_t>(&((program_parameter_shapes->object).at((name))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_parameter_shapes_names( extern "C" migraphx_status migraphx_program_parameter_shapes_names(
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes) const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(out == nullptr) if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(program_parameter_shapes == nullptr) if(program_parameter_shapes == nullptr)
...@@ -588,21 +621,24 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_names( ...@@ -588,21 +621,24 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_names(
auto&& api_result = migraphx::get_names((program_parameter_shapes->object)); auto&& api_result = migraphx::get_names((program_parameter_shapes->object));
std::copy(api_result.begin(), api_result.end(), out); std::copy(api_result.begin(), api_result.end(), out);
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters) migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters)
{ {
return migraphx::try_([&] { destroy((program_parameters)); }); auto api_error_result = migraphx::try_([&] { destroy((program_parameters)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters) migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*program_parameters = object_cast<migraphx_program_parameters_t>( *program_parameters = object_cast<migraphx_program_parameters_t>(
allocate<std::unordered_map<std::string, migraphx::argument>>()); allocate<std::unordered_map<std::string, migraphx::argument>>());
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
...@@ -610,7 +646,7 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters ...@@ -610,7 +646,7 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
const char* name, const char* name,
const_migraphx_argument_t argument) const_migraphx_argument_t argument)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program_parameters == nullptr) if(program_parameters == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameters: Null pointer"); "Bad parameter program_parameters: Null pointer");
...@@ -618,85 +654,95 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters ...@@ -618,85 +654,95 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
(program_parameters->object)[(name)] = (argument->object); (program_parameters->object)[(name)] = (argument->object);
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments) extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments)
{ {
return migraphx::try_([&] { destroy((arguments)); }); auto api_error_result = migraphx::try_([&] { destroy((arguments)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments) extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr) if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = (arguments->object).size(); *out = (arguments->object).size();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx) migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr) if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = object_cast<const_migraphx_argument_t>(&((arguments->object).at((idx)))); *out = object_cast<const_migraphx_argument_t>(&((arguments->object).at((idx))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes) extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes)
{ {
return migraphx::try_([&] { destroy((shapes)); }); auto api_error_result = migraphx::try_([&] { destroy((shapes)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes) extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr) if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = (shapes->object).size(); *out = (shapes->object).size();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx) migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr) if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((shapes->object).at((idx)))); *out = object_cast<const_migraphx_shape_t>(&((shapes->object).at((idx))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module) extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(module == nullptr) if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
migraphx::print_module((module->object)); migraphx::print_module((module->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program) extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{ {
return migraphx::try_([&] { destroy((program)); }); auto api_error_result = migraphx::try_([&] { destroy((program)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program) migraphx_program_t program)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).get_main_module()); *out = object_cast<migraphx_module_t>((program->object).get_main_module());
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program, extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target, migraphx_target_t target,
migraphx_compile_options_t options) migraphx_compile_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(target == nullptr) if(target == nullptr)
...@@ -705,91 +751,105 @@ extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program, ...@@ -705,91 +751,105 @@ extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
(program->object).compile((target->object), (options->object)); (program->object).compile((target->object), (options->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out, migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out,
migraphx_program_t program) migraphx_program_t program)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = *out =
allocate<migraphx_program_parameter_shapes_t>((program->object).get_parameter_shapes()); allocate<migraphx_program_parameter_shapes_t>((program->object).get_parameter_shapes());
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out, extern "C" migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_program_t program) migraphx_program_t program)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = allocate<migraphx_shapes_t>(migraphx::get_output_shapes((program->object))); *out = allocate<migraphx_shapes_t>(migraphx::get_output_shapes((program->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t program) extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t program)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
migraphx::print_program((program->object)); migraphx::print_program((program->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program) extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
(program->object).sort(); (program->object).sort();
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out, extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params) migraphx_program_parameters_t params)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(params == nullptr) if(params == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer");
*out = allocate<migraphx_arguments_t>(migraphx::run((program->object), (params->object))); *out = allocate<migraphx_arguments_t>(migraphx::run((program->object), (params->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x) migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(x == nullptr) if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((program->object), (x->object)); *out = migraphx::equal((program->object), (x->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation) extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation)
{ {
return migraphx::try_([&] { destroy((operation)); }); auto api_error_result = migraphx::try_([&] { destroy((operation)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
migraphx_operation_create(migraphx_operation_t* operation, const char* name, const char* attributes) const char* name,
const char* attributes,
...)
{ {
return migraphx::try_([&] { va_list vlist;
va_start(vlist, attributes);
auto api_error_result = migraphx::try_([&] {
*operation = object_cast<migraphx_operation_t>( *operation = object_cast<migraphx_operation_t>(
allocate<migraphx::operation>(migraphx::create_op((name), (attributes)))); allocate<migraphx::operation>(migraphx::create_op((name), (attributes), (vlist))));
}); });
va_end(vlist);
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation) migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(out == nullptr) if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(operation == nullptr) if(operation == nullptr)
...@@ -798,46 +858,51 @@ migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operati ...@@ -798,46 +858,51 @@ migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operati
auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out); auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out);
*it = '\0'; *it = '\0';
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options) migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(options == nullptr) if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::load((name), (options->object))); *out = allocate<migraphx_program_t>(migraphx::load((name), (options->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options) migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(p == nullptr) if(p == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer");
if(options == nullptr) if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::save((p->object), (name), (options->object)); migraphx::save((p->object), (name), (options->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options) extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options)
{ {
return migraphx::try_([&] { destroy((onnx_options)); }); auto api_error_result = migraphx::try_([&] { destroy((onnx_options)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options) extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*onnx_options = object_cast<migraphx_onnx_options_t>(allocate<migraphx::onnx_options>()); *onnx_options = object_cast<migraphx_onnx_options_t>(allocate<migraphx::onnx_options>());
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape( extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size) migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr) if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr and dims_size != 0) if(dims == nullptr and dims_size != 0)
...@@ -845,96 +910,107 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape( ...@@ -845,96 +910,107 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape( migraphx::set_input_parameter_shape(
(onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size))); (onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value) migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr) if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_dim_value((onnx_options->object), (value)); migraphx::set_default_dim_value((onnx_options->object), (value));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value) int64_t value)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr) if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_loop_iterations((onnx_options->object), (value)); migraphx::set_default_loop_iterations((onnx_options->object), (value));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options) extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options)
{ {
return migraphx::try_([&] { destroy((file_options)); }); auto api_error_result = migraphx::try_([&] { destroy((file_options)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options) extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*file_options = object_cast<migraphx_file_options_t>(allocate<migraphx::file_options>()); *file_options = object_cast<migraphx_file_options_t>(allocate<migraphx::file_options>());
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format) migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(file_options == nullptr) if(file_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter file_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter file_options: Null pointer");
migraphx::set_file_format((file_options->object), (format)); migraphx::set_file_format((file_options->object), (format));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options) migraphx_compile_options_destroy(migraphx_compile_options_t compile_options)
{ {
return migraphx::try_([&] { destroy((compile_options)); }); auto api_error_result = migraphx::try_([&] { destroy((compile_options)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options) migraphx_compile_options_create(migraphx_compile_options_t* compile_options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*compile_options = *compile_options =
object_cast<migraphx_compile_options_t>(allocate<migraphx::compile_options>()); object_cast<migraphx_compile_options_t>(allocate<migraphx::compile_options>());
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value) migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr) if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer"); "Bad parameter compile_options: Null pointer");
migraphx::set_offload_copy((compile_options->object), (value)); migraphx::set_offload_copy((compile_options->object), (value));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value) migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr) if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer"); "Bad parameter compile_options: Null pointer");
migraphx::set_fast_math((compile_options->object), (value)); migraphx::set_fast_math((compile_options->object), (value));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options) migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(options == nullptr) if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_onnx((name), (options->object))); *out = allocate<migraphx_program_t>(migraphx::parse_onnx((name), (options->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
...@@ -942,40 +1018,44 @@ extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, ...@@ -942,40 +1018,44 @@ extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
size_t size, size_t size,
migraphx_onnx_options_t options) migraphx_onnx_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(options == nullptr) if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>( *out = allocate<migraphx_program_t>(
migraphx::parse_onnx_buffer((data), (size), (options->object))); migraphx::parse_onnx_buffer((data), (size), (options->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options) extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options)
{ {
return migraphx::try_([&] { destroy((tf_options)); }); auto api_error_result = migraphx::try_([&] { destroy((tf_options)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options) extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*tf_options = object_cast<migraphx_tf_options_t>(allocate<migraphx::tf_options>()); *tf_options = object_cast<migraphx_tf_options_t>(allocate<migraphx::tf_options>());
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, extern "C" migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
bool is_nhwc) bool is_nhwc)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr) if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_nhwc((tf_options->object), (is_nhwc)); migraphx::set_nhwc((tf_options->object), (is_nhwc));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape( extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size) migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr) if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(dims == nullptr and dims_size != 0) if(dims == nullptr and dims_size != 0)
...@@ -983,23 +1063,25 @@ extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape( ...@@ -983,23 +1063,25 @@ extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape( migraphx::set_input_parameter_shape(
(tf_options->object), (name), (std::vector<size_t>(dims, dims + dims_size))); (tf_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value) migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr) if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_default_dim_value((tf_options->object), (value)); migraphx::set_default_dim_value((tf_options->object), (value));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options, extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options,
const char** names, const char** names,
size_t names_size) size_t names_size)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr) if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(names == nullptr and names_size != 0) if(names == nullptr and names_size != 0)
...@@ -1007,96 +1089,106 @@ extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_opti ...@@ -1007,96 +1089,106 @@ extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_opti
migraphx::set_output_names((tf_options->object), migraphx::set_output_names((tf_options->object),
(std::vector<const char*>(names, names + names_size))); (std::vector<const char*>(names, names + names_size)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options) migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(options == nullptr) if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_tf((name), (options->object))); *out = allocate<migraphx_program_t>(migraphx::parse_tf((name), (options->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names) migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names)
{ {
return migraphx::try_([&] { destroy((quantize_op_names)); }); auto api_error_result = migraphx::try_([&] { destroy((quantize_op_names)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names) migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*quantize_op_names = *quantize_op_names =
object_cast<migraphx_quantize_op_names_t>(allocate<std::vector<std::string>>()); object_cast<migraphx_quantize_op_names_t>(allocate<std::vector<std::string>>());
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name) migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(quantize_op_names == nullptr) if(quantize_op_names == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_op_names: Null pointer"); "Bad parameter quantize_op_names: Null pointer");
(quantize_op_names->object).push_back((name)); (quantize_op_names->object).push_back((name));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, extern "C" migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog,
migraphx_quantize_op_names_t name) migraphx_quantize_op_names_t name)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(prog == nullptr) if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(name == nullptr) if(name == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
migraphx::quantize_fp16_with_op_names((prog->object), (name->object)); migraphx::quantize_fp16_with_op_names((prog->object), (name->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog) extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(prog == nullptr) if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
migraphx::quantize_fp16((prog->object)); migraphx::quantize_fp16((prog->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options) migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options)
{ {
return migraphx::try_([&] { destroy((quantize_int8_options)); }); auto api_error_result = migraphx::try_([&] { destroy((quantize_int8_options)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options) migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*quantize_int8_options = object_cast<migraphx_quantize_int8_options_t>( *quantize_int8_options = object_cast<migraphx_quantize_int8_options_t>(
allocate<migraphx::quantize_int8_options>()); allocate<migraphx::quantize_int8_options>());
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options, migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options,
const char* name) const char* name)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr) if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer"); "Bad parameter quantize_int8_options: Null pointer");
migraphx::add_op_name((quantize_int8_options->object), (name)); migraphx::add_op_name((quantize_int8_options->object), (name));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data( extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data) migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr) if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer"); "Bad parameter quantize_int8_options: Null pointer");
...@@ -1104,13 +1196,14 @@ extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data( ...@@ -1104,13 +1196,14 @@ extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer");
migraphx::add_calibration_data((quantize_int8_options->object), (data->object)); migraphx::add_calibration_data((quantize_int8_options->object), (data->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog, extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target, migraphx_target_t target,
migraphx_quantize_int8_options_t options) migraphx_quantize_int8_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(prog == nullptr) if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(target == nullptr) if(target == nullptr)
...@@ -1119,4 +1212,5 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog, ...@@ -1119,4 +1212,5 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::quantize_int8_wrap((prog->object), (target->object), (options->object)); migraphx::quantize_int8_wrap((prog->object), (target->object), (options->object));
}); });
return api_error_result;
} }
...@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); ...@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation, migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes); const char* attributes,
...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation); migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
......
...@@ -252,7 +252,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -252,7 +252,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout; const size_t* pout;
size_t pout_size; size_t pout_size;
call(&migraphx_shape_lengths, &pout, &pout_size, this->get_handle_ptr()); call(&migraphx_shape_lengths, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size); return {pout, pout + pout_size};
} }
std::vector<size_t> strides() const std::vector<size_t> strides() const
...@@ -260,7 +260,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -260,7 +260,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout; const size_t* pout;
size_t pout_size; size_t pout_size;
call(&migraphx_shape_strides, &pout, &pout_size, this->get_handle_ptr()); call(&migraphx_shape_strides, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size); return {pout, pout + pout_size};
} }
migraphx_shape_datatype_t type() const migraphx_shape_datatype_t type() const
...@@ -312,7 +312,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -312,7 +312,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr()); call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return shape(pout); return {pout};
} }
char* data() const char* data() const
...@@ -325,9 +325,8 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -325,9 +325,8 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
/// Generate an argument using random data /// Generate an argument using random data
static argument generate(shape ps, size_t pseed = 0) static argument generate(shape ps, size_t pseed = 0)
{ {
return argument( return {make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed),
make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed), own{}};
own{});
} }
friend bool operator==(const argument& px, const argument& py) friend bool operator==(const argument& px, const argument& py)
...@@ -378,7 +377,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -378,7 +377,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname); call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return shape(pout); return {pout};
} }
std::vector<const char*> names() const std::vector<const char*> names() const
...@@ -438,7 +437,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -438,7 +437,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
const_migraphx_argument_t pout; const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return argument(pout); return {pout};
} }
struct iterator_read struct iterator_read
...@@ -449,7 +448,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -449,7 +448,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
const_migraphx_argument_t pout; const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx); call(&migraphx_arguments_get, &pout, self, pidx);
return argument(pout); return {pout};
} }
}; };
}; };
...@@ -471,7 +470,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -471,7 +470,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return shape(pout); return {pout};
} }
struct iterator_read struct iterator_read
...@@ -481,7 +480,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -481,7 +480,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, self, pidx); call(&migraphx_shapes_get, &pout, self, pidx);
return shape(pout); return {pout};
} }
}; };
}; };
...@@ -599,16 +598,17 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -599,16 +598,17 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); } operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
operation(const char* name, const char* attributes = nullptr) template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{ {
this->make_handle(&migraphx_operation_create, name, attributes); this->make_handle(&migraphx_operation_create, name, attributes, xs...);
} }
std::string name() std::string name()
{ {
std::array<char, 1024> out_name; std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr()); call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return std::string(out_name.data()); return {out_name.data()};
} }
}; };
......
...@@ -212,7 +212,9 @@ def program(h): ...@@ -212,7 +212,9 @@ def program(h):
@auto_handle() @auto_handle()
def operation(h): def operation(h):
h.constructor('create', h.constructor('create',
api.params(name='const char*', attributes='const char*'), api.params(name='const char*',
attributes='const char*',
vlist='...'),
fname='migraphx::create_op') fname='migraphx::create_op')
h.method('name', returns='std::string') h.method('name', returns='std::string')
......
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/apply_alpha_beta.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
const literal& alpha,
const literal& beta)
{
auto a = args[0];
auto b = args[1];
auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0))
{
auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
if(a->get_shape().type() != input_type)
{
a = m.insert_instruction(pos, make_op("convert", {{"target_type", input_type}}), a);
}
}
auto op_res = m.insert_instruction(pos, op, a, b);
if(args.size() == 3)
{
if(not float_equal(beta.at<float>(0), 0.0) && args[2]->get_shape().elements() > 0)
{
auto out_lens = op_res->get_shape().lens();
auto c = args[2];
auto c_lens = c->get_shape().lens();
input_type = c->get_shape().type();
if(out_lens != c_lens)
{
c = m.insert_instruction(
pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
auto beta_literal = m.add_literal(beta);
auto beta_c = insert_common_op(m, pos, migraphx::make_op("mul"), {c, beta_literal});
if(beta_c->get_shape().type() != input_type)
{
beta_c = m.insert_instruction(
pos, migraphx::make_op("convert", {{"target_type", input_type}}), beta_c);
}
return m.insert_instruction(pos, migraphx::make_op("add"), op_res, beta_c);
}
}
return op_res;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment