Commit 0369e974 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'batch_report' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 3a474fca d70fd0df
...@@ -80,10 +80,10 @@ jobs: ...@@ -80,10 +80,10 @@ jobs:
uses: pat-s/always-upload-cache@v2.1.3 uses: pat-s/always-upload-cache@v2.1.3
with: with:
path: cppcheck-cache path: cppcheck-cache
key: cppcheck-cache-${{ steps.cache_timestamp.outputs.timestamp }} key: cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ steps.cache_timestamp.outputs.timestamp }}
restore-keys: | restore-keys: |
cppcheck-cache-${{ steps.cache_timestamp.outputs.timestamp }} cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ steps.cache_timestamp.outputs.timestamp }}
cppcheck-cache- cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-
- name: Build the Docker image - name: Build the Docker image
run: docker build . --file hip-clang.docker --tag migraphx run: docker build . --file hip-clang.docker --tag migraphx
......
...@@ -36,7 +36,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED) ...@@ -36,7 +36,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 1.3) rocm_setup_version(VERSION 2.0)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
...@@ -190,6 +190,7 @@ rocm_enable_cppcheck( ...@@ -190,6 +190,7 @@ rocm_enable_cppcheck(
shadowVariable shadowVariable
unsafeClassDivZero unsafeClassDivZero
definePrefix:*test/include/test.hpp definePrefix:*test/include/test.hpp
ctuOneDefinitionRuleViolation:*test/*
useSmartPointer:*src/api/api.cpp useSmartPointer:*src/api/api.cpp
useSmartPointer:*make_shared_array.hpp useSmartPointer:*make_shared_array.hpp
constParameter:*src/targets/gpu/*.cpp constParameter:*src/targets/gpu/*.cpp
......
...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local ...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386 RUN dpkg --add-architecture i386
# Add rocm repository # Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.2/ xenial main > /etc/apt/sources.list.d/rocm.list' RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.5/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
...@@ -32,6 +32,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -32,6 +32,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
software-properties-common \ software-properties-common \
wget \ wget \
rocm-device-libs \ rocm-device-libs \
hip-base \
libnuma-dev \
miopen-hip \ miopen-hip \
rocblas \ rocblas \
zlib1g-dev && \ zlib1g-dev && \
......
...@@ -20,7 +20,7 @@ def rocmtestnode(Map conf) { ...@@ -20,7 +20,7 @@ def rocmtestnode(Map conf) {
rm -rf build rm -rf build
mkdir build mkdir build
cd build cd build
CXX=${compiler} CXXFLAGS='-Werror -Wno-fallback' cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} .. CXX=${compiler} CXXFLAGS='-Werror' cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} ..
make -j\$(nproc) generate all doc package check VERBOSE=1 make -j\$(nproc) generate all doc package check VERBOSE=1
""" """
echo cmd echo cmd
...@@ -75,6 +75,8 @@ def rocmnodename(name) { ...@@ -75,6 +75,8 @@ def rocmnodename(name) {
node_name = "${rocmtest_name} && fiji"; node_name = "${rocmtest_name} && fiji";
} else if(name == "vega") { } else if(name == "vega") {
node_name = "${rocmtest_name} && vega"; node_name = "${rocmtest_name} && vega";
} else if(name == "navi21") {
node_name = "${rocmtest_name} && navi21";
} else if(name == "nogpu") { } else if(name == "nogpu") {
return rocmtest_name; return rocmtest_name;
} }
...@@ -110,6 +112,10 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build -> ...@@ -110,6 +112,10 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
def debug_flags = "-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" def debug_flags = "-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
} }
}, clang_release_navi: rocmnode('navi21') { cmake_build ->
stage('HIP Clang Release Navi') {
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=release")
}
} }
def onnxnode(name, body) { def onnxnode(name, body) {
......
...@@ -152,6 +152,24 @@ ...@@ -152,6 +152,24 @@
<summary>Else statement is not necessary.</summary> <summary>Else statement is not necessary.</summary>
</message> </message>
</rule> </rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[((?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)*) (\w+) ; \2 = static_cast < \1 > (\([^()]*(?-1)*[^()]*\)) ;]]></pattern>
<message>
<id>RedundantCast</id>
<severity>style</severity>
<summary>Static cast is redundant.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[auto (\w+) ; \1 = static_cast < (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* > (\([^()]*(?-1)*[^()]*\)) ;]]></pattern>
<message>
<id>RedundantCast</id>
<severity>style</severity>
<summary>Static cast is redundant.</summary>
</message>
</rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[\? (true|false) : (true|false)]]></pattern> <pattern><![CDATA[\? (true|false) : (true|false)]]></pattern>
......
sphinx==2.2.2 docutils==0.17.1
breathe==4.13.1 sphinx==4.2.0
breathe==4.31.0
sphinx_rtd_theme==1.0.0
# git+https://github.com/arximboldi/breathe@fix-node-parent # git+https://github.com/arximboldi/breathe@fix-node-parent
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
# #
# import os # import os
# import sys # import sys
from datetime import date
import re
# sys.path.insert(0, os.path.abspath('.')) # sys.path.insert(0, os.path.abspath('.'))
# -- General configuration ------------------------------------------------ # -- General configuration ------------------------------------------------
...@@ -29,7 +31,9 @@ ...@@ -29,7 +31,9 @@
# Add any Sphinx extension module names here, as strings. They can be # Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones. # ones.
extensions = ['breathe', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode'] extensions = [
'breathe', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode', 'sphinx_rtd_theme'
]
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ['_templates']
...@@ -45,7 +49,7 @@ master_doc = 'index' ...@@ -45,7 +49,7 @@ master_doc = 'index'
# General information about the project. # General information about the project.
project = u'MIGraphX' project = u'MIGraphX'
copyright = u'2018, AMD' copyright = u'2018-{}, AMD'.format(date.today().year)
author = u'AMD' author = u'AMD'
# The version info for the project you're documenting, acts as replacement for # The version info for the project you're documenting, acts as replacement for
...@@ -53,9 +57,12 @@ author = u'AMD' ...@@ -53,9 +57,12 @@ author = u'AMD'
# built documents. # built documents.
# #
# The short X.Y version. # The short X.Y version.
version = u'0.1' with open('../../CMakeLists.txt') as file:
version = next((re.findall('[0-9.]+', line)[0]
for line in file.readlines()
if 'rocm_setup_version' in line))
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
release = u'0.1' release = version
# The language for content autogenerated by Sphinx. Refer to documentation # The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages. # for a list of supported languages.
...@@ -82,7 +89,7 @@ todo_include_todos = False ...@@ -82,7 +89,7 @@ todo_include_todos = False
# The theme to use for HTML and HTML Help pages. See the documentation for # The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes. # a list of builtin themes.
# #
html_theme = 'alabaster' html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a theme # Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the # further. For a list of options available for each theme, see the
......
...@@ -25,8 +25,8 @@ migraphx::save(p, output_file); ...@@ -25,8 +25,8 @@ migraphx::save(p, output_file);
``` ```
migraphx::program p = ... <migraphx::program>; migraphx::program p = ... <migraphx::program>;
migraphx_file_options options; migraphx::file_options options;
options.format = "msgpack"; options.set_file_format("msgpack");
migraphx::save(p, output_file, options); migraphx::save(p, output_file, options);
``` ```
...@@ -41,15 +41,15 @@ p = migraphx::load(input_file); ...@@ -41,15 +41,15 @@ p = migraphx::load(input_file);
``` ```
migraphx::program p; migraphx::program p;
migraphx_file_options options; migraphx::file_options options;
options.format = "msgpack"; options.set_file_format("msgpack");
p = migraphx::load(input_file, options); p = migraphx::load(input_file, options);
``` ```
To load a program that has been saved in JSON format: To load a program that has been saved in JSON format:
``` ```
migraphx::program p; migraphx::program p;
migraphx_file_options options; migraphx::file_options options;
options.format = "json"; options.set_file_format("json");
p = migraphx::load(input_file, options); p = migraphx::load(input_file, options);
``` ```
......
...@@ -44,15 +44,15 @@ int main(int argc, char** argv) ...@@ -44,15 +44,15 @@ int main(int argc, char** argv)
std::string format = load_arg; std::string format = load_arg;
if(format == "json") if(format == "json")
{ {
migraphx_file_options options; migraphx::file_options options;
options.format = "json"; options.set_file_format("json");
p = migraphx::load(input_file, options); p = migraphx::load(input_file, options);
} }
else if(format == "msgpack") else if(format == "msgpack")
{ {
migraphx_file_options options; migraphx::file_options options;
options.format = "msgpack"; options.set_file_format("msgpack");
p = migraphx::load(input_file, options); p = migraphx::load(input_file, options);
} }
else else
p = migraphx::load(input_file); p = migraphx::load(input_file);
...@@ -80,8 +80,8 @@ int main(int argc, char** argv) ...@@ -80,8 +80,8 @@ int main(int argc, char** argv)
output_file = save_arg == nullptr ? "out" : save_arg; output_file = save_arg == nullptr ? "out" : save_arg;
output_file.append(".msgpack"); output_file.append(".msgpack");
migraphx_file_options options; migraphx::file_options options;
options.format = "msgpack"; options.set_file_format("msgpack");
migraphx::save(p, output_file.c_str(), options); migraphx::save(p, output_file.c_str(), options);
std::cout << "Program has been saved as ./" << output_file << std::endl; std::cout << "Program has been saved as ./" << output_file << std::endl;
} }
......
...@@ -60,14 +60,14 @@ migraphx::quantize_int8(prog, targ, quant_opts); ...@@ -60,14 +60,14 @@ migraphx::quantize_int8(prog, targ, quant_opts);
## Compilation ## Compilation
Network graphs saved in e.g. ONNX or protobuf format are not target-specific. In order to run inference, we must compile the graph into a target-specific program. Network graphs saved in e.g. ONNX or protobuf format are not target-specific. In order to run inference, we must compile the graph into a target-specific program.
Two options may be turned on (default for both is `false`) when compiling: Two options may be turned on when compiling:
- `bool offload_copy`: For targets with offloaded memory (such as the gpu), this will insert instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory. - `set_offload_copy(bool value)`: For targets with offloaded memory (such as the gpu), this will insert instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory. Default value is `false` for offload_copy.
- `bool fast_math`: Optimize math functions to use faster approximate versions. There may be slight accuracy degredation when enabled. - `set_fast_math(bool value)`: Optimize math functions to use faster approximate versions. There may be slight accuracy degredation when enabled. Default value is `true` for fast_math.
The following snippet assumes `targ` has been set as "gpu", and will compile the program without the fast_math optimization. The following snippet assumes `targ` has been set as "gpu", and will compile the program without the fast_math optimization.
``` ```
migraphx_compile_options comp_opts; migraphx::compile_options comp_opts;
comp_opts.offload_copy = true; comp_opts.set_offload_copy();
prog.compile(targ, comp_opts); prog.compile(targ, comp_opts);
``` ```
......
...@@ -99,8 +99,8 @@ int main(int argc, char** argv) ...@@ -99,8 +99,8 @@ int main(int argc, char** argv)
if(GPU) if(GPU)
{ {
migraphx_compile_options comp_opts; migraphx::compile_options comp_opts;
comp_opts.offload_copy = true; comp_opts.set_offload_copy();
prog.compile(targ, comp_opts); prog.compile(targ, comp_opts);
} }
else else
......
...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local ...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386 RUN dpkg --add-architecture i386
# Add rocm repository # Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.2/ xenial main > /etc/apt/sources.list.d/rocm.list' RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.5/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
...@@ -29,6 +29,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -29,6 +29,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
software-properties-common \ software-properties-common \
wget \ wget \
rocm-device-libs \ rocm-device-libs \
hip-base \
libnuma-dev \
miopen-hip \ miopen-hip \
rocblas \ rocblas \
zlib1g-dev && \ zlib1g-dev && \
......
...@@ -7,6 +7,7 @@ include(CheckCXXLinkerFlag) ...@@ -7,6 +7,7 @@ include(CheckCXXLinkerFlag)
add_library(migraphx add_library(migraphx
adjust_allocation.cpp adjust_allocation.cpp
analyze_streams.cpp analyze_streams.cpp
apply_alpha_beta.cpp
argument.cpp argument.cpp
auto_contiguous.cpp auto_contiguous.cpp
common.cpp common.cpp
...@@ -14,7 +15,6 @@ add_library(migraphx ...@@ -14,7 +15,6 @@ add_library(migraphx
convert_to_json.cpp convert_to_json.cpp
cpp_generator.cpp cpp_generator.cpp
dead_code_elimination.cpp dead_code_elimination.cpp
decompose.cpp
dom_info.cpp dom_info.cpp
dynamic_loader.cpp dynamic_loader.cpp
eliminate_allocation.cpp eliminate_allocation.cpp
...@@ -26,6 +26,7 @@ add_library(migraphx ...@@ -26,6 +26,7 @@ add_library(migraphx
eliminate_pad.cpp eliminate_pad.cpp
env.cpp env.cpp
file_buffer.cpp file_buffer.cpp
fuse_pointwise.cpp
generate.cpp generate.cpp
inline_module.cpp inline_module.cpp
insert_pad.cpp insert_pad.cpp
...@@ -52,7 +53,6 @@ add_library(migraphx ...@@ -52,7 +53,6 @@ add_library(migraphx
reduce_dims.cpp reduce_dims.cpp
register_op.cpp register_op.cpp
register_target.cpp register_target.cpp
remap.cpp
simplify_qdq.cpp simplify_qdq.cpp
rewrite_batchnorm.cpp rewrite_batchnorm.cpp
rewrite_pooling.cpp rewrite_pooling.cpp
...@@ -131,8 +131,11 @@ register_migraphx_ops( ...@@ -131,8 +131,11 @@ register_migraphx_ops(
multibroadcast multibroadcast
multinomial multinomial
neg neg
nonmaxsuppression
nonzero
outline outline
pad pad
pointwise
pooling pooling
pow pow
prefix_scan_sum prefix_scan_sum
...@@ -153,6 +156,7 @@ register_migraphx_ops( ...@@ -153,6 +156,7 @@ register_migraphx_ops(
rnn_last_cell_output rnn_last_cell_output
rnn_last_hs_output rnn_last_hs_output
rnn_var_sl_last_output rnn_var_sl_last_output
roialign
round round
rsqrt rsqrt
scalar scalar
...@@ -198,6 +202,9 @@ target_link_libraries(migraphx PRIVATE -ldl) ...@@ -198,6 +202,9 @@ target_link_libraries(migraphx PRIVATE -ldl)
target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
find_package(Threads)
target_link_libraries(migraphx PUBLIC Threads::Threads)
find_package(msgpack REQUIRED) find_package(msgpack REQUIRED)
target_link_libraries(migraphx PRIVATE msgpackc-cxx) target_link_libraries(migraphx PRIVATE msgpackc-cxx)
# Make this available to the tests # Make this available to the tests
...@@ -235,6 +242,7 @@ rocm_export_targets( ...@@ -235,6 +242,7 @@ rocm_export_targets(
TARGETS migraphx::migraphx migraphx_all_targets TARGETS migraphx::migraphx migraphx_all_targets
NAMESPACE migraphx:: NAMESPACE migraphx::
DEPENDS DEPENDS
Threads
${PACKAGE_DEPENDS} ${PACKAGE_DEPENDS}
) )
......
...@@ -3,7 +3,7 @@ add_library(migraphx_c ...@@ -3,7 +3,7 @@ add_library(migraphx_c
api.cpp api.cpp
) )
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c) set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 2.0) rocm_set_soversion(migraphx_c 3.0)
rocm_clang_tidy_check(migraphx_c) rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets) target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
......
This diff is collapsed.
...@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); ...@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation, migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes); const char* attributes,
...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation); migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
......
...@@ -252,7 +252,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -252,7 +252,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout; const size_t* pout;
size_t pout_size; size_t pout_size;
call(&migraphx_shape_lengths, &pout, &pout_size, this->get_handle_ptr()); call(&migraphx_shape_lengths, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size); return {pout, pout + pout_size};
} }
std::vector<size_t> strides() const std::vector<size_t> strides() const
...@@ -260,7 +260,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -260,7 +260,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout; const size_t* pout;
size_t pout_size; size_t pout_size;
call(&migraphx_shape_strides, &pout, &pout_size, this->get_handle_ptr()); call(&migraphx_shape_strides, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size); return {pout, pout + pout_size};
} }
migraphx_shape_datatype_t type() const migraphx_shape_datatype_t type() const
...@@ -312,7 +312,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -312,7 +312,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr()); call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return shape(pout); return {pout};
} }
char* data() const char* data() const
...@@ -325,9 +325,8 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -325,9 +325,8 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
/// Generate an argument using random data /// Generate an argument using random data
static argument generate(shape ps, size_t pseed = 0) static argument generate(shape ps, size_t pseed = 0)
{ {
return argument( return {make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed),
make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed), own{}};
own{});
} }
friend bool operator==(const argument& px, const argument& py) friend bool operator==(const argument& px, const argument& py)
...@@ -378,7 +377,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -378,7 +377,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname); call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return shape(pout); return {pout};
} }
std::vector<const char*> names() const std::vector<const char*> names() const
...@@ -438,7 +437,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -438,7 +437,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
const_migraphx_argument_t pout; const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return argument(pout); return {pout};
} }
struct iterator_read struct iterator_read
...@@ -449,7 +448,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -449,7 +448,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
const_migraphx_argument_t pout; const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx); call(&migraphx_arguments_get, &pout, self, pidx);
return argument(pout); return {pout};
} }
}; };
}; };
...@@ -471,7 +470,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -471,7 +470,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return shape(pout); return {pout};
} }
struct iterator_read struct iterator_read
...@@ -481,7 +480,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -481,7 +480,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, self, pidx); call(&migraphx_shapes_get, &pout, self, pidx);
return shape(pout); return {pout};
} }
}; };
}; };
...@@ -599,16 +598,17 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -599,16 +598,17 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); } operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
operation(const char* name, const char* attributes = nullptr) template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{ {
this->make_handle(&migraphx_operation_create, name, attributes); this->make_handle(&migraphx_operation_create, name, attributes, xs...);
} }
std::string name() std::string name()
{ {
std::array<char, 1024> out_name; std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr()); call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return std::string(out_name.data()); return {out_name.data()};
} }
}; };
......
...@@ -212,7 +212,9 @@ def program(h): ...@@ -212,7 +212,9 @@ def program(h):
@auto_handle() @auto_handle()
def operation(h): def operation(h):
h.constructor('create', h.constructor('create',
api.params(name='const char*', attributes='const char*'), api.params(name='const char*',
attributes='const char*',
vlist='...'),
fname='migraphx::create_op') fname='migraphx::create_op')
h.method('name', returns='std::string') h.method('name', returns='std::string')
......
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/apply_alpha_beta.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
const literal& alpha,
const literal& beta)
{
auto a = args[0];
auto b = args[1];
auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0))
{
auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
if(a->get_shape().type() != input_type)
{
a = m.insert_instruction(pos, make_op("convert", {{"target_type", input_type}}), a);
}
}
auto op_res = m.insert_instruction(pos, op, a, b);
if(args.size() == 3)
{
if(not float_equal(beta.at<float>(0), 0.0) && args[2]->get_shape().elements() > 0)
{
auto out_lens = op_res->get_shape().lens();
auto c = args[2];
auto c_lens = c->get_shape().lens();
input_type = c->get_shape().type();
if(out_lens != c_lens)
{
c = m.insert_instruction(
pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
auto beta_literal = m.add_literal(beta);
auto beta_c = insert_common_op(m, pos, migraphx::make_op("mul"), {c, beta_literal});
if(beta_c->get_shape().type() != input_type)
{
beta_c = m.insert_instruction(
pos, migraphx::make_op("convert", {{"target_type", input_type}}), beta_c);
}
return m.insert_instruction(pos, migraphx::make_op("add"), op_res, beta_c);
}
}
return op_res;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment