Commit 0369e974 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'batch_report' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 3a474fca d70fd0df
......@@ -80,10 +80,10 @@ jobs:
uses: pat-s/always-upload-cache@v2.1.3
with:
path: cppcheck-cache
key: cppcheck-cache-${{ steps.cache_timestamp.outputs.timestamp }}
key: cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ steps.cache_timestamp.outputs.timestamp }}
restore-keys: |
cppcheck-cache-${{ steps.cache_timestamp.outputs.timestamp }}
cppcheck-cache-
cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ steps.cache_timestamp.outputs.timestamp }}
cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-
- name: Build the Docker image
run: docker build . --file hip-clang.docker --tag migraphx
......
......@@ -36,7 +36,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
include(ROCMSetupVersion)
rocm_setup_version(VERSION 1.3)
rocm_setup_version(VERSION 2.0)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON )
......@@ -190,6 +190,7 @@ rocm_enable_cppcheck(
shadowVariable
unsafeClassDivZero
definePrefix:*test/include/test.hpp
ctuOneDefinitionRuleViolation:*test/*
useSmartPointer:*src/api/api.cpp
useSmartPointer:*make_shared_array.hpp
constParameter:*src/targets/gpu/*.cpp
......
......@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386
# Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.2/ xenial main > /etc/apt/sources.list.d/rocm.list'
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.5/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
......@@ -32,6 +32,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
software-properties-common \
wget \
rocm-device-libs \
hip-base \
libnuma-dev \
miopen-hip \
rocblas \
zlib1g-dev && \
......
......@@ -20,7 +20,7 @@ def rocmtestnode(Map conf) {
rm -rf build
mkdir build
cd build
CXX=${compiler} CXXFLAGS='-Werror -Wno-fallback' cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} ..
CXX=${compiler} CXXFLAGS='-Werror' cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} ..
make -j\$(nproc) generate all doc package check VERBOSE=1
"""
echo cmd
......@@ -75,6 +75,8 @@ def rocmnodename(name) {
node_name = "${rocmtest_name} && fiji";
} else if(name == "vega") {
node_name = "${rocmtest_name} && vega";
} else if(name == "navi21") {
node_name = "${rocmtest_name} && navi21";
} else if(name == "nogpu") {
return rocmtest_name;
}
......@@ -110,6 +112,10 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
def debug_flags = "-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
}
}, clang_release_navi: rocmnode('navi21') { cmake_build ->
stage('HIP Clang Release Navi') {
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=release")
}
}
def onnxnode(name, body) {
......
......@@ -152,6 +152,24 @@
<summary>Else statement is not necessary.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[((?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)*) (\w+) ; \2 = static_cast < \1 > (\([^()]*(?-1)*[^()]*\)) ;]]></pattern>
<message>
<id>RedundantCast</id>
<severity>style</severity>
<summary>Static cast is redundant.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[auto (\w+) ; \1 = static_cast < (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* > (\([^()]*(?-1)*[^()]*\)) ;]]></pattern>
<message>
<id>RedundantCast</id>
<severity>style</severity>
<summary>Static cast is redundant.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[\? (true|false) : (true|false)]]></pattern>
......
pfultz2/rocm-recipes
facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake
ccache@v4.1
danmar/cppcheck@4a8a78a9258fd56bc21e55b5b49a0f09bc8fa750 -DHAVE_RULES=1
pcre,pfultz2/pcre@8.45 -H sha256:d6f7182602a775a7d500a0cedca6449af0400c6493951513046d17615ed0bf11
danmar/cppcheck@2.6 -DHAVE_RULES=1
RadeonOpenCompute/rocm-cmake@ececd2eccae4d01e7ec154efe90ac43ebf4df317 --build
-f requirements.txt
sphinx==2.2.2
breathe==4.13.1
docutils==0.17.1
sphinx==4.2.0
breathe==4.31.0
sphinx_rtd_theme==1.0.0
# git+https://github.com/arximboldi/breathe@fix-node-parent
......@@ -18,6 +18,8 @@
#
# import os
# import sys
from datetime import date
import re
# sys.path.insert(0, os.path.abspath('.'))
# -- General configuration ------------------------------------------------
......@@ -29,7 +31,9 @@
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['breathe', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode']
extensions = [
'breathe', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode', 'sphinx_rtd_theme'
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
......@@ -45,7 +49,7 @@ master_doc = 'index'
# General information about the project.
project = u'MIGraphX'
copyright = u'2018, AMD'
copyright = u'2018-{}, AMD'.format(date.today().year)
author = u'AMD'
# The version info for the project you're documenting, acts as replacement for
......@@ -53,9 +57,12 @@ author = u'AMD'
# built documents.
#
# The short X.Y version.
version = u'0.1'
with open('../../CMakeLists.txt') as file:
version = next((re.findall('[0-9.]+', line)[0]
for line in file.readlines()
if 'rocm_setup_version' in line))
# The full version, including alpha/beta/rc tags.
release = u'0.1'
release = version
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
......@@ -82,7 +89,7 @@ todo_include_todos = False
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'alabaster'
html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
......
......@@ -25,8 +25,8 @@ migraphx::save(p, output_file);
```
migraphx::program p = ... <migraphx::program>;
migraphx_file_options options;
options.format = "msgpack";
migraphx::file_options options;
options.set_file_format("msgpack");
migraphx::save(p, output_file, options);
```
......@@ -41,15 +41,15 @@ p = migraphx::load(input_file);
```
migraphx::program p;
migraphx_file_options options;
options.format = "msgpack";
migraphx::file_options options;
options.set_file_format("msgpack");
p = migraphx::load(input_file, options);
```
To load a program that has been saved in JSON format:
```
migraphx::program p;
migraphx_file_options options;
options.format = "json";
migraphx::file_options options;
options.set_file_format("json");
p = migraphx::load(input_file, options);
```
......
......@@ -44,15 +44,15 @@ int main(int argc, char** argv)
std::string format = load_arg;
if(format == "json")
{
migraphx_file_options options;
options.format = "json";
p = migraphx::load(input_file, options);
migraphx::file_options options;
options.set_file_format("json");
p = migraphx::load(input_file, options);
}
else if(format == "msgpack")
{
migraphx_file_options options;
options.format = "msgpack";
p = migraphx::load(input_file, options);
migraphx::file_options options;
options.set_file_format("msgpack");
p = migraphx::load(input_file, options);
}
else
p = migraphx::load(input_file);
......@@ -80,8 +80,8 @@ int main(int argc, char** argv)
output_file = save_arg == nullptr ? "out" : save_arg;
output_file.append(".msgpack");
migraphx_file_options options;
options.format = "msgpack";
migraphx::file_options options;
options.set_file_format("msgpack");
migraphx::save(p, output_file.c_str(), options);
std::cout << "Program has been saved as ./" << output_file << std::endl;
}
......
......@@ -60,14 +60,14 @@ migraphx::quantize_int8(prog, targ, quant_opts);
## Compilation
Network graphs saved in e.g. ONNX or protobuf format are not target-specific. In order to run inference, we must compile the graph into a target-specific program.
Two options may be turned on (default for both is `false`) when compiling:
- `bool offload_copy`: For targets with offloaded memory (such as the gpu), this will insert instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory.
- `bool fast_math`: Optimize math functions to use faster approximate versions. There may be slight accuracy degredation when enabled.
Two options may be turned on when compiling:
- `set_offload_copy(bool value)`: For targets with offloaded memory (such as the gpu), this will insert instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory. Default value is `false` for offload_copy.
- `set_fast_math(bool value)`: Optimize math functions to use faster approximate versions. There may be slight accuracy degredation when enabled. Default value is `true` for fast_math.
The following snippet assumes `targ` has been set as "gpu", and will compile the program without the fast_math optimization.
```
migraphx_compile_options comp_opts;
comp_opts.offload_copy = true;
migraphx::compile_options comp_opts;
comp_opts.set_offload_copy();
prog.compile(targ, comp_opts);
```
......
......@@ -99,8 +99,8 @@ int main(int argc, char** argv)
if(GPU)
{
migraphx_compile_options comp_opts;
comp_opts.offload_copy = true;
migraphx::compile_options comp_opts;
comp_opts.set_offload_copy();
prog.compile(targ, comp_opts);
}
else
......
......@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386
# Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.2/ xenial main > /etc/apt/sources.list.d/rocm.list'
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.5/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
......@@ -29,6 +29,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
software-properties-common \
wget \
rocm-device-libs \
hip-base \
libnuma-dev \
miopen-hip \
rocblas \
zlib1g-dev && \
......
......@@ -7,6 +7,7 @@ include(CheckCXXLinkerFlag)
add_library(migraphx
adjust_allocation.cpp
analyze_streams.cpp
apply_alpha_beta.cpp
argument.cpp
auto_contiguous.cpp
common.cpp
......@@ -14,7 +15,6 @@ add_library(migraphx
convert_to_json.cpp
cpp_generator.cpp
dead_code_elimination.cpp
decompose.cpp
dom_info.cpp
dynamic_loader.cpp
eliminate_allocation.cpp
......@@ -26,6 +26,7 @@ add_library(migraphx
eliminate_pad.cpp
env.cpp
file_buffer.cpp
fuse_pointwise.cpp
generate.cpp
inline_module.cpp
insert_pad.cpp
......@@ -52,7 +53,6 @@ add_library(migraphx
reduce_dims.cpp
register_op.cpp
register_target.cpp
remap.cpp
simplify_qdq.cpp
rewrite_batchnorm.cpp
rewrite_pooling.cpp
......@@ -131,8 +131,11 @@ register_migraphx_ops(
multibroadcast
multinomial
neg
nonmaxsuppression
nonzero
outline
pad
pointwise
pooling
pow
prefix_scan_sum
......@@ -153,6 +156,7 @@ register_migraphx_ops(
rnn_last_cell_output
rnn_last_hs_output
rnn_var_sl_last_output
roialign
round
rsqrt
scalar
......@@ -198,6 +202,9 @@ target_link_libraries(migraphx PRIVATE -ldl)
target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
find_package(Threads)
target_link_libraries(migraphx PUBLIC Threads::Threads)
find_package(msgpack REQUIRED)
target_link_libraries(migraphx PRIVATE msgpackc-cxx)
# Make this available to the tests
......@@ -235,6 +242,7 @@ rocm_export_targets(
TARGETS migraphx::migraphx migraphx_all_targets
NAMESPACE migraphx::
DEPENDS
Threads
${PACKAGE_DEPENDS}
)
......
......@@ -3,7 +3,7 @@ add_library(migraphx_c
api.cpp
)
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 2.0)
rocm_set_soversion(migraphx_c 3.0)
rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
......
......@@ -13,6 +13,7 @@
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
......@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
}
operation create_op(const char* name, const char* attributes)
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
operation create_op(const char* name, const char* attributes, va_list vlist)
{
std::string sattributes = attributes == nullptr ? "" : attributes;
std::vector<char> buffer(sattributes.size() * 2);
std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist);
value v = value::object{};
if(attributes != nullptr)
{
v = from_json_string(convert_to_json(std::string(attributes)));
v = from_json_string(convert_to_json(std::string(buffer.data())));
}
auto op = make_op(name, v);
return op;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
template <class T>
bool equal(const T& x, const T& y)
{
......@@ -368,7 +381,8 @@ struct migraphx_quantize_int8_options
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{
return migraphx::try_([&] { destroy((shape)); });
auto api_error_result = migraphx::try_([&] { destroy((shape)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
......@@ -376,13 +390,14 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
size_t* lengths,
size_t lengths_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)),
(std::vector<size_t>(lengths, lengths + lengths_size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
......@@ -392,7 +407,7 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
size_t* strides,
size_t strides_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
if(strides == nullptr and strides_size != 0)
......@@ -402,21 +417,23 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
(std::vector<size_t>(lengths, lengths + lengths_size)),
(std::vector<size_t>(strides, strides + strides_size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type))));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
......@@ -425,12 +442,13 @@ migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data();
*out_size = api_result.size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
......@@ -439,127 +457,141 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data();
*out_size = api_result.size();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = migraphx::to_shape_type((shape->object).type());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).bytes();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((shape->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{
return migraphx::try_([&] { destroy((argument)); });
auto api_error_result = migraphx::try_([&] { destroy((argument)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*argument = object_cast<migraphx_argument_t>(
allocate<migraphx::argument>((shape->object), (buffer)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((argument->object).get_shape()));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = (argument->object).data();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((argument->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(s == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
*out = allocate<migraphx_argument_t>(migraphx::generate_argument((s->object), (seed)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target)
{
return migraphx::try_([&] { destroy((target)); });
auto api_error_result = migraphx::try_([&] { destroy((target)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*target = object_cast<migraphx_target_t>(
allocate<migraphx::target>(migraphx::get_target((name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] { destroy((program_parameter_shapes)); });
auto api_error_result = migraphx::try_([&] { destroy((program_parameter_shapes)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameter_shapes_size(size_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer");
*out = (program_parameter_shapes->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
......@@ -567,19 +599,20 @@ migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes,
const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer");
*out =
object_cast<const_migraphx_shape_t>(&((program_parameter_shapes->object).at((name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_parameter_shapes_names(
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(program_parameter_shapes == nullptr)
......@@ -588,21 +621,24 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_names(
auto&& api_result = migraphx::get_names((program_parameter_shapes->object));
std::copy(api_result.begin(), api_result.end(), out);
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters)
{
return migraphx::try_([&] { destroy((program_parameters)); });
auto api_error_result = migraphx::try_([&] { destroy((program_parameters)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*program_parameters = object_cast<migraphx_program_parameters_t>(
allocate<std::unordered_map<std::string, migraphx::argument>>());
});
return api_error_result;
}
extern "C" migraphx_status
......@@ -610,7 +646,7 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
const char* name,
const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameters == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameters: Null pointer");
......@@ -618,85 +654,95 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
(program_parameters->object)[(name)] = (argument->object);
});
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments)
{
return migraphx::try_([&] { destroy((arguments)); });
auto api_error_result = migraphx::try_([&] { destroy((arguments)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = (arguments->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = object_cast<const_migraphx_argument_t>(&((arguments->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes)
{
return migraphx::try_([&] { destroy((shapes)); });
auto api_error_result = migraphx::try_([&] { destroy((shapes)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = (shapes->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((shapes->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
migraphx::print_module((module->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{
return migraphx::try_([&] { destroy((program)); });
auto api_error_result = migraphx::try_([&] { destroy((program)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).get_main_module());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(target == nullptr)
......@@ -705,91 +751,105 @@ extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
(program->object).compile((target->object), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out =
allocate<migraphx_program_parameter_shapes_t>((program->object).get_parameter_shapes());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = allocate<migraphx_shapes_t>(migraphx::get_output_shapes((program->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
migraphx::print_program((program->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
(program->object).sort();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program,
migraphx_program_parameters_t params)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(params == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer");
*out = allocate<migraphx_arguments_t>(migraphx::run((program->object), (params->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((program->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation)
{
return migraphx::try_([&] { destroy((operation)); });
auto api_error_result = migraphx::try_([&] { destroy((operation)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_operation_create(migraphx_operation_t* operation, const char* name, const char* attributes)
extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes,
...)
{
return migraphx::try_([&] {
va_list vlist;
va_start(vlist, attributes);
auto api_error_result = migraphx::try_([&] {
*operation = object_cast<migraphx_operation_t>(
allocate<migraphx::operation>(migraphx::create_op((name), (attributes))));
allocate<migraphx::operation>(migraphx::create_op((name), (attributes), (vlist))));
});
va_end(vlist);
return api_error_result;
}
extern "C" migraphx_status
migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(operation == nullptr)
......@@ -798,46 +858,51 @@ migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operati
auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out);
*it = '\0';
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::load((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(p == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer");
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::save((p->object), (name), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options)
{
return migraphx::try_([&] { destroy((onnx_options)); });
auto api_error_result = migraphx::try_([&] { destroy((onnx_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*onnx_options = object_cast<migraphx_onnx_options_t>(allocate<migraphx::onnx_options>());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr and dims_size != 0)
......@@ -845,96 +910,107 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape(
(onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_dim_value((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_loop_iterations((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options)
{
return migraphx::try_([&] { destroy((file_options)); });
auto api_error_result = migraphx::try_([&] { destroy((file_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*file_options = object_cast<migraphx_file_options_t>(allocate<migraphx::file_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(file_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter file_options: Null pointer");
migraphx::set_file_format((file_options->object), (format));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options)
{
return migraphx::try_([&] { destroy((compile_options)); });
auto api_error_result = migraphx::try_([&] { destroy((compile_options)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*compile_options =
object_cast<migraphx_compile_options_t>(allocate<migraphx::compile_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_offload_copy((compile_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_fast_math((compile_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_onnx((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
......@@ -942,40 +1018,44 @@ extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
size_t size,
migraphx_onnx_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(
migraphx::parse_onnx_buffer((data), (size), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options)
{
return migraphx::try_([&] { destroy((tf_options)); });
auto api_error_result = migraphx::try_([&] { destroy((tf_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*tf_options = object_cast<migraphx_tf_options_t>(allocate<migraphx::tf_options>());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
bool is_nhwc)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_nhwc((tf_options->object), (is_nhwc));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(dims == nullptr and dims_size != 0)
......@@ -983,23 +1063,25 @@ extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape(
(tf_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_default_dim_value((tf_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options,
const char** names,
size_t names_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(names == nullptr and names_size != 0)
......@@ -1007,96 +1089,106 @@ extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_opti
migraphx::set_output_names((tf_options->object),
(std::vector<const char*>(names, names + names_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_tf((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names)
{
return migraphx::try_([&] { destroy((quantize_op_names)); });
auto api_error_result = migraphx::try_([&] { destroy((quantize_op_names)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*quantize_op_names =
object_cast<migraphx_quantize_op_names_t>(allocate<std::vector<std::string>>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_op_names == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_op_names: Null pointer");
(quantize_op_names->object).push_back((name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog,
migraphx_quantize_op_names_t name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(name == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
migraphx::quantize_fp16_with_op_names((prog->object), (name->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
migraphx::quantize_fp16((prog->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options)
{
return migraphx::try_([&] { destroy((quantize_int8_options)); });
auto api_error_result = migraphx::try_([&] { destroy((quantize_int8_options)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*quantize_int8_options = object_cast<migraphx_quantize_int8_options_t>(
allocate<migraphx::quantize_int8_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options,
const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer");
migraphx::add_op_name((quantize_int8_options->object), (name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer");
......@@ -1104,13 +1196,14 @@ extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer");
migraphx::add_calibration_data((quantize_int8_options->object), (data->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target,
migraphx_quantize_int8_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(target == nullptr)
......@@ -1119,4 +1212,5 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::quantize_int8_wrap((prog->object), (target->object), (options->object));
});
return api_error_result;
}
......@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes);
const char* attributes,
...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
......
......@@ -252,7 +252,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout;
size_t pout_size;
call(&migraphx_shape_lengths, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size);
return {pout, pout + pout_size};
}
std::vector<size_t> strides() const
......@@ -260,7 +260,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout;
size_t pout_size;
call(&migraphx_shape_strides, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size);
return {pout, pout + pout_size};
}
migraphx_shape_datatype_t type() const
......@@ -312,7 +312,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return shape(pout);
return {pout};
}
char* data() const
......@@ -325,9 +325,8 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
/// Generate an argument using random data
static argument generate(shape ps, size_t pseed = 0)
{
return argument(
make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed),
own{});
return {make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed),
own{}};
}
friend bool operator==(const argument& px, const argument& py)
......@@ -378,7 +377,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return shape(pout);
return {pout};
}
std::vector<const char*> names() const
......@@ -438,7 +437,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return argument(pout);
return {pout};
}
struct iterator_read
......@@ -449,7 +448,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx);
return argument(pout);
return {pout};
}
};
};
......@@ -471,7 +470,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return shape(pout);
return {pout};
}
struct iterator_read
......@@ -481,7 +480,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, self, pidx);
return shape(pout);
return {pout};
}
};
};
......@@ -599,16 +598,17 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
operation(const char* name, const char* attributes = nullptr)
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{
this->make_handle(&migraphx_operation_create, name, attributes);
this->make_handle(&migraphx_operation_create, name, attributes, xs...);
}
std::string name()
{
std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return std::string(out_name.data());
return {out_name.data()};
}
};
......
......@@ -212,7 +212,9 @@ def program(h):
@auto_handle()
def operation(h):
h.constructor('create',
api.params(name='const char*', attributes='const char*'),
api.params(name='const char*',
attributes='const char*',
vlist='...'),
fname='migraphx::create_op')
h.method('name', returns='std::string')
......
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/apply_alpha_beta.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
const literal& alpha,
const literal& beta)
{
auto a = args[0];
auto b = args[1];
auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0))
{
auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
if(a->get_shape().type() != input_type)
{
a = m.insert_instruction(pos, make_op("convert", {{"target_type", input_type}}), a);
}
}
auto op_res = m.insert_instruction(pos, op, a, b);
if(args.size() == 3)
{
if(not float_equal(beta.at<float>(0), 0.0) && args[2]->get_shape().elements() > 0)
{
auto out_lens = op_res->get_shape().lens();
auto c = args[2];
auto c_lens = c->get_shape().lens();
input_type = c->get_shape().type();
if(out_lens != c_lens)
{
c = m.insert_instruction(
pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
auto beta_literal = m.add_literal(beta);
auto beta_c = insert_common_op(m, pos, migraphx::make_op("mul"), {c, beta_literal});
if(beta_c->get_shape().type() != input_type)
{
beta_c = m.insert_instruction(
pos, migraphx::make_op("convert", {{"target_type", input_type}}), beta_c);
}
return m.insert_instruction(pos, migraphx::make_op("add"), op_res, beta_c);
}
}
return op_res;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment