Unverified Commit 7f97b8ef authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'simplify_1_mul_div_ops' into divide_by_zero_check

parents 2ba401f0 d1fed367
...@@ -53,6 +53,7 @@ jobs: ...@@ -53,6 +53,7 @@ jobs:
CXX=/opt/rocm/llvm/bin/clang++ CC=/opt/rocm/llvm/bin/clang cmake \ CXX=/opt/rocm/llvm/bin/clang++ CC=/opt/rocm/llvm/bin/clang cmake \
-DMIGRAPHX_ENABLE_GPU=On \ -DMIGRAPHX_ENABLE_GPU=On \
-DMIGRAPHX_ENABLE_CPU=On \ -DMIGRAPHX_ENABLE_CPU=On \
-DMIGRAPHX_ENABLE_FPGA=On \
-DROCM_ENABLE_GH_ANNOTATIONS=On \ -DROCM_ENABLE_GH_ANNOTATIONS=On \
-DCLANG_TIDY_DEPEND_ON_TARGET=Off \ -DCLANG_TIDY_DEPEND_ON_TARGET=Off \
-DCLANG_TIDY_CACHE=/data/tidy-cache \ -DCLANG_TIDY_CACHE=/data/tidy-cache \
...@@ -267,7 +268,9 @@ jobs: ...@@ -267,7 +268,9 @@ jobs:
lcov --directory . --capture --output-file $(pwd)/coverage.info lcov --directory . --capture --output-file $(pwd)/coverage.info
lcov --remove $(pwd)/coverage.info '/usr/*' --output-file $(pwd)/coverage.info lcov --remove $(pwd)/coverage.info '/usr/*' --output-file $(pwd)/coverage.info
lcov --list $(pwd)/coverage.info lcov --list $(pwd)/coverage.info
curl -s https://codecov.io/bash | bash curl -Os https://uploader.codecov.io/latest/linux/codecov
chmod +x codecov
./codecov -t ${CODECOV_TOKEN}
echo "Uploaded" echo "Uploaded"
linux-fpga: linux-fpga:
...@@ -363,5 +366,7 @@ jobs: ...@@ -363,5 +366,7 @@ jobs:
# lcov --directory . --capture --output-file $(pwd)/coverage.info # lcov --directory . --capture --output-file $(pwd)/coverage.info
# lcov --remove $(pwd)/coverage.info '/usr/*' --output-file $(pwd)/coverage.info # lcov --remove $(pwd)/coverage.info '/usr/*' --output-file $(pwd)/coverage.info
# lcov --list $(pwd)/coverage.info # lcov --list $(pwd)/coverage.info
# curl -s https://codecov.io/bash | bash # curl -Os https://uploader.codecov.io/latest/linux/codecov
# echo "Uploaded" # chmod +x codecov
\ No newline at end of file # ./codecov -t ${CODECOV_TOKEN}
# echo "Uploaded"
name: MIGraphX Performance Tests name: MIGraphX Performance Tests
on: on:
push:
branches: [develop]
pull_request: pull_request:
branches: [develop] branches: [develop]
types: [opened, synchronize, closed]
schedule: schedule:
- cron: "0 5 * * 1-6" - cron: "0 5 * * 1-6"
...@@ -28,7 +26,7 @@ on: ...@@ -28,7 +26,7 @@ on:
required: true required: true
default: '-s' default: '-s'
concurrency: benchmark concurrency: "perftest-${{ github.head_ref || github.base_ref || 'schedule' }}"
jobs: jobs:
release: release:
......
...@@ -61,11 +61,9 @@ check_type_size("half_float::detail::expr" HALF_EXPR LANGUAGE CXX) ...@@ -61,11 +61,9 @@ check_type_size("half_float::detail::expr" HALF_EXPR LANGUAGE CXX)
set(CMAKE_REQUIRED_INCLUDES) set(CMAKE_REQUIRED_INCLUDES)
set(CMAKE_EXTRA_INCLUDE_FILES) set(CMAKE_EXTRA_INCLUDE_FILES)
find_package(nlohmann_json 3.8.0 REQUIRED)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 2.3) rocm_setup_version(VERSION 2.4)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
...@@ -214,6 +212,7 @@ rocm_enable_cppcheck( ...@@ -214,6 +212,7 @@ rocm_enable_cppcheck(
ConfigurationNotChecked ConfigurationNotChecked
unmatchedSuppression unmatchedSuppression
unusedFunction unusedFunction
ctuPointerArith
noExplicitConstructor noExplicitConstructor
passedByValue passedByValue
unusedStructMember unusedStructMember
......
...@@ -77,7 +77,7 @@ RUN cget -p $PREFIX install ccache@v4.1 ...@@ -77,7 +77,7 @@ RUN cget -p $PREFIX install ccache@v4.1
RUN cget -p /opt/cmake install kitware/cmake@v3.13.4 RUN cget -p /opt/cmake install kitware/cmake@v3.13.4
ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG ONNXRUNTIME_BRANCH=master ARG ONNXRUNTIME_BRANCH=main
ARG ONNXRUNTIME_COMMIT=24f1bd6156cf5968bbc76dfb0e801a9b9c56b9fc ARG ONNXRUNTIME_COMMIT=24f1bd6156cf5968bbc76dfb0e801a9b9c56b9fc
RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime && \ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime && \
cd onnxruntime && \ cd onnxruntime && \
...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@26a4b3cfc0a1a15181490f24ae461608fef1b04e -DBUILD_MIXR_TARGET=On RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@e8e77eb16be413d301ea8509726d47f265d9011f -DBUILD_MIXR_TARGET=On
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -33,7 +33,7 @@ def rocmtestnode(Map conf) { ...@@ -33,7 +33,7 @@ def rocmtestnode(Map conf) {
} }
} }
node(name) { node(name) {
withEnv(['HSA_ENABLE_SDMA=0', 'MIOPEN_DEBUG_GCN_ASM_KERNELS=0']) { withEnv(['HSA_ENABLE_SDMA=0']) {
stage("checkout ${variant}") { stage("checkout ${variant}") {
checkout scm checkout scm
} }
......
...@@ -46,6 +46,7 @@ The following is a list of prerequisites required to build MIGraphX source. ...@@ -46,6 +46,7 @@ The following is a list of prerequisites required to build MIGraphX source.
* [pybind11](https://pybind11.readthedocs.io/en/stable/) - for python bindings * [pybind11](https://pybind11.readthedocs.io/en/stable/) - for python bindings
* [JSON](https://github.com/nlohmann/json) - for model serialization to json string format * [JSON](https://github.com/nlohmann/json) - for model serialization to json string format
* [MessagePack](https://msgpack.org/index.html) - for model serialization to binary format * [MessagePack](https://msgpack.org/index.html) - for model serialization to binary format
* [SQLite3](https://www.sqlite.org/index.html) - to create database of kernels' tuning information or execute queries on existing database
#### Use the ROCm build tool [rbuild](https://github.com/RadeonOpenCompute/rbuild). #### Use the ROCm build tool [rbuild](https://github.com/RadeonOpenCompute/rbuild).
......
...@@ -107,7 +107,7 @@ ...@@ -107,7 +107,7 @@
<summary>Use make_shared or make_unique instead of new</summary> <summary>Use make_shared or make_unique instead of new</summary>
</message> </message>
</rule> </rule>
<!-- <rule> <rule>
<tokenlist>raw</tokenlist> <tokenlist>raw</tokenlist>
<pattern><![CDATA[ \|\| ]]></pattern> <pattern><![CDATA[ \|\| ]]></pattern>
<message> <message>
...@@ -124,7 +124,7 @@ ...@@ -124,7 +124,7 @@
<severity>style</severity> <severity>style</severity>
<summary>Use 'not' instead of !</summary> <summary>Use 'not' instead of !</summary>
</message> </message>
</rule> --> </rule>
<!-- <rule> <!-- <rule>
<tokenlist>raw</tokenlist> <tokenlist>raw</tokenlist>
<pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern> <pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern>
......
...@@ -25,6 +25,6 @@ pfultz2/rocm-recipes ...@@ -25,6 +25,6 @@ pfultz2/rocm-recipes
facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake
ccache@v4.1 ccache@v4.1
pcre,pfultz2/pcre@8.45 -H sha256:d6f7182602a775a7d500a0cedca6449af0400c6493951513046d17615ed0bf11 pcre,pfultz2/pcre@8.45 -H sha256:d6f7182602a775a7d500a0cedca6449af0400c6493951513046d17615ed0bf11
danmar/cppcheck@2.8 -DHAVE_RULES=1 danmar/cppcheck@2.9 -DHAVE_RULES=1
RadeonOpenCompute/rocm-cmake@1ebf7e7bc61bb5e949c171562b421264065230a7 --build RadeonOpenCompute/rocm-cmake@1ebf7e7bc61bb5e949c171562b421264065230a7 --build
-f requirements.txt -f requirements.txt
...@@ -84,6 +84,12 @@ argument ...@@ -84,6 +84,12 @@ argument
Construct an argument from a python buffer. This can include numpy arrays. Construct an argument from a python buffer. This can include numpy arrays.
.. py:method:: data_ptr()
Returns the address to the underlying argument data.
:rtype: int
.. py:method:: get_shape() .. py:method:: get_shape()
Returns the shape of the argument. Returns the shape of the argument.
...@@ -113,7 +119,16 @@ argument ...@@ -113,7 +119,16 @@ argument
:param shape s: Shape of argument to fill. :param shape s: Shape of argument to fill.
:param int value: Value to fill in the argument. :param int value: Value to fill in the argument.
:rtype argument :rtype: argument
.. py:function:: argument_from_pointer(shape, address)
Create argument from data stored in given address without copy.
:param shape shape: Shape of the data stored in address.
:param long address: Memory address of data from another source
:rtype: argument
target target
------ ------
......
...@@ -53,8 +53,8 @@ int main(int argc, char** argv) ...@@ -53,8 +53,8 @@ int main(int argc, char** argv)
migraphx::program p; migraphx::program p;
if(cmdOptionExists(argv + 2, argv + argc, "--parse") || if(cmdOptionExists(argv + 2, argv + argc, "--parse") or
!cmdOptionExists(argv + 2, argv + argc, "--load")) not cmdOptionExists(argv + 2, argv + argc, "--load"))
{ {
std::cout << "Parsing ONNX File" << std::endl; std::cout << "Parsing ONNX File" << std::endl;
migraphx::onnx_options options; migraphx::onnx_options options;
......
...@@ -45,6 +45,13 @@ __global__ void vector_square(T* C_d, const T* A_d, size_t N) ...@@ -45,6 +45,13 @@ __global__ void vector_square(T* C_d, const T* A_d, size_t N)
struct square_custom_op final : migraphx::experimental_custom_op_base struct square_custom_op final : migraphx::experimental_custom_op_base
{ {
virtual std::string name() const override { return "square_custom_op"; } virtual std::string name() const override { return "square_custom_op"; }
// flag to identify whether custom op runs on the GPU or on the host.
// Based on this flag MIGraphX would inject necessary copies to and from GPU for the input and
// output buffers as necessary. Therefore if custom_op runs on GPU then it can assume its input
// buffers are in GPU memory, and similarly for the host
virtual bool runs_on_offload_target() const override { return true; }
virtual migraphx::argument virtual migraphx::argument
compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override
{ {
...@@ -54,7 +61,7 @@ struct square_custom_op final : migraphx::experimental_custom_op_base ...@@ -54,7 +61,7 @@ struct square_custom_op final : migraphx::experimental_custom_op_base
// is output argument, so it should be returned from compute method. // is output argument, so it should be returned from compute method.
auto* input_buffer = reinterpret_cast<float*>(inputs[0].data()); auto* input_buffer = reinterpret_cast<float*>(inputs[0].data());
auto* output_buffer = reinterpret_cast<float*>(inputs[1].data()); auto* output_buffer = reinterpret_cast<float*>(inputs[1].data());
size_t n_elements = inputs[0].get_shape().bytes() / sizeof(inputs[0].get_shape().type()); size_t n_elements = inputs[0].get_shape().elements();
MIGRAPHX_HIP_ASSERT(hipSetDevice(0)); MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
const unsigned blocks = 512; const unsigned blocks = 512;
const unsigned threads_per_block = 256; const unsigned threads_per_block = 256;
...@@ -103,7 +110,7 @@ int main(int argc, const char* argv[]) ...@@ -103,7 +110,7 @@ int main(int argc, const char* argv[])
options.set_offload_copy(); options.set_offload_copy();
p.compile(migraphx::target("gpu"), options); p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters pp; migraphx::program_parameters pp;
std::vector<float> x_data(s.bytes() / sizeof(s.type())); std::vector<float> x_data(s.elements());
std::iota(x_data.begin(), x_data.end(), 0); std::iota(x_data.begin(), x_data.end(), 0);
pp.add("x", migraphx::argument(s, x_data.data())); pp.add("x", migraphx::argument(s, x_data.data()));
auto results = p.eval(pp); auto results = p.eval(pp);
......
...@@ -93,6 +93,13 @@ inline auto make_activation_descriptor(miopenActivationMode_t mode, ...@@ -93,6 +93,13 @@ inline auto make_activation_descriptor(miopenActivationMode_t mode,
struct abs_custom_op final : migraphx::experimental_custom_op_base struct abs_custom_op final : migraphx::experimental_custom_op_base
{ {
virtual std::string name() const override { return "abs_custom_op"; } virtual std::string name() const override { return "abs_custom_op"; }
// flag to identify whether custom op runs on the GPU or on the host.
// Based on this flag MIGraphX would inject necessary copies to and from GPU for the input and
// output buffers as necessary. Therefore if custom_op runs on GPU then it can assume its input
// buffers are in GPU memory, and similarly for the host
virtual bool runs_on_offload_target() const override { return true; }
virtual migraphx::argument compute(migraphx::context ctx, virtual migraphx::argument compute(migraphx::context ctx,
migraphx::shape output_shape, migraphx::shape output_shape,
migraphx::arguments args) const override migraphx::arguments args) const override
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <algorithm> #include <algorithm>
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <rocblas.h> #include <rocblas/rocblas.h>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> // MIGraphX's C++ API #include <migraphx/migraphx.hpp> // MIGraphX's C++ API
#include <numeric> #include <numeric>
...@@ -51,16 +51,25 @@ rocblas_handle create_rocblas_handle_ptr(migraphx::context& ctx) ...@@ -51,16 +51,25 @@ rocblas_handle create_rocblas_handle_ptr(migraphx::context& ctx)
struct sscal_custom_op final : migraphx::experimental_custom_op_base struct sscal_custom_op final : migraphx::experimental_custom_op_base
{ {
virtual std::string name() const override { return "sscal_custom_op"; } virtual std::string name() const override { return "sscal_custom_op"; }
// flag to identify whether custom op runs on the GPU or on the host.
// Based on this flag MIGraphX would inject necessary copies to and from GPU for the input and
// output buffers as necessary. Therefore if custom_op runs on GPU then it can assume its input
// buffers are in GPU memory, and similarly for the host
virtual bool runs_on_offload_target() const override { return true; }
virtual migraphx::argument compute(migraphx::context ctx, virtual migraphx::argument compute(migraphx::context ctx,
migraphx::shape output_shape, migraphx::shape output_shape,
migraphx::arguments args) const override migraphx::arguments args) const override
{ {
// create rocblas stream handle // create rocblas stream handle
auto rocblas_handle = create_rocblas_handle_ptr(ctx); auto rb_handle = create_rocblas_handle_ptr(ctx);
rocblas_int n = args[1].get_shape().lengths()[0]; MIGRAPHX_ROCBLAS_ASSERT(rocblas_set_pointer_mode(rb_handle, rocblas_pointer_mode_device));
float* alpha = reinterpret_cast<float*>(args[0].data()); rocblas_int n = args[1].get_shape().lengths()[0];
float* vec_ptr = reinterpret_cast<float*>(args[1].data()); float* alpha = reinterpret_cast<float*>(args[0].data());
MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rocblas_handle, n, alpha, vec_ptr, 1)); float* vec_ptr = reinterpret_cast<float*>(args[1].data());
MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rb_handle, n, alpha, vec_ptr, 1));
MIGRAPHX_ROCBLAS_ASSERT(rocblas_destroy_handle(rb_handle));
return args[1]; return args[1];
} }
...@@ -70,7 +79,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base ...@@ -70,7 +79,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
{ {
throw std::runtime_error("sscal_custom_op must have 2 input arguments"); throw std::runtime_error("sscal_custom_op must have 2 input arguments");
} }
if(inputs[0].lengths().size() != 1 || inputs[0].lengths()[0] != 1) if(inputs[0].lengths().size() != 1 or inputs[0].lengths()[0] != 1)
{ {
throw std::runtime_error("first input argument to sscal_custom_op must be a scalar"); throw std::runtime_error("first input argument to sscal_custom_op must be a scalar");
} }
...@@ -105,7 +114,7 @@ int main(int argc, const char* argv[]) ...@@ -105,7 +114,7 @@ int main(int argc, const char* argv[])
options.set_offload_copy(); options.set_offload_copy();
p.compile(migraphx::target("gpu"), options); p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters pp; migraphx::program_parameters pp;
std::vector<float> x_data(x_shape.bytes() / sizeof(x_shape.type())); std::vector<float> x_data(x_shape.elements());
std::vector<float> scale_data{-1}; std::vector<float> scale_data{-1};
std::iota(x_data.begin(), x_data.end(), 0); std::iota(x_data.begin(), x_data.end(), 0);
pp.add("x", migraphx::argument(x_shape, x_data.data())); pp.add("x", migraphx::argument(x_shape, x_data.data()));
......
...@@ -51,16 +51,16 @@ int main(int argc, char** argv) ...@@ -51,16 +51,16 @@ int main(int argc, char** argv)
char** begin = argv + 1; char** begin = argv + 1;
char** end = argv + argc; char** end = argv + argc;
const bool CPU = (std::find(begin, end, std::string("-c")) != end) || const bool CPU = (std::find(begin, end, std::string("-c")) != end) or
std::find(begin, end, std::string("--cpu")) != end; std::find(begin, end, std::string("--cpu")) != end;
const bool GPU = std::find(begin, end, std::string("-g")) != end || const bool GPU = std::find(begin, end, std::string("-g")) != end or
std::find(begin, end, std::string("--gpu")) != end; std::find(begin, end, std::string("--gpu")) != end;
const bool FP16 = std::find(begin, end, std::string("-f")) != end || const bool FP16 = std::find(begin, end, std::string("-f")) != end or
std::find(begin, end, std::string("--fp16")) != end; std::find(begin, end, std::string("--fp16")) != end;
const bool INT8 = std::find(begin, end, std::string("-i")) != end || const bool INT8 = std::find(begin, end, std::string("-i")) != end or
std::find(begin, end, std::string("--int8")) != end; std::find(begin, end, std::string("--int8")) != end;
const bool CALIB = std::find(begin, end, std::string("--cal")) != end; const bool CALIB = std::find(begin, end, std::string("--cal")) != end;
const bool PRINT = std::find(begin, end, std::string("-p")) != end || const bool PRINT = std::find(begin, end, std::string("-p")) != end or
std::find(begin, end, std::string("--print")) != end; std::find(begin, end, std::string("--print")) != end;
migraphx::program prog; migraphx::program prog;
...@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit) ...@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit)
const int HEIGHT = 28; const int HEIGHT = 28;
const int WIDTH = 28; const int WIDTH = 28;
if(!file.is_open()) if(not file.is_open())
{ {
return; return;
} }
......
...@@ -27,3 +27,4 @@ live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f4753828517 ...@@ -27,3 +27,4 @@ live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f4753828517
half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969 half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
...@@ -65,6 +65,7 @@ add_library(migraphx ...@@ -65,6 +65,7 @@ add_library(migraphx
operation.cpp operation.cpp
opt/memory_coloring.cpp opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp opt/memory_coloring_impl.cpp
pad_calc.cpp
pass_manager.cpp pass_manager.cpp
permutation.cpp permutation.cpp
preallocate_param.cpp preallocate_param.cpp
...@@ -79,7 +80,9 @@ add_library(migraphx ...@@ -79,7 +80,9 @@ add_library(migraphx
register_target.cpp register_target.cpp
replace_allocate.cpp replace_allocate.cpp
simplify_qdq.cpp simplify_qdq.cpp
sqlite.cpp
rewrite_batchnorm.cpp rewrite_batchnorm.cpp
rewrite_gelu.cpp
rewrite_pooling.cpp rewrite_pooling.cpp
rewrite_quantization.cpp rewrite_quantization.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
...@@ -88,7 +91,6 @@ add_library(migraphx ...@@ -88,7 +91,6 @@ add_library(migraphx
shape.cpp shape.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
target_assignments.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
...@@ -134,6 +136,7 @@ register_migraphx_ops( ...@@ -134,6 +136,7 @@ register_migraphx_ops(
exp exp
flatten flatten
floor floor
fmod
gather gather
gathernd gathernd
get_tuple_elem get_tuple_elem
...@@ -156,6 +159,7 @@ register_migraphx_ops( ...@@ -156,6 +159,7 @@ register_migraphx_ops(
lstm lstm
max max
min min
mod
mul mul
multibroadcast multibroadcast
multinomial multinomial
...@@ -239,6 +243,13 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU ...@@ -239,6 +243,13 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU
find_package(Threads) find_package(Threads)
target_link_libraries(migraphx PUBLIC Threads::Threads) target_link_libraries(migraphx PUBLIC Threads::Threads)
find_package(nlohmann_json 3.8.0 REQUIRED)
target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
find_package(PkgConfig)
pkg_check_modules(SQLITE3 REQUIRED IMPORTED_TARGET sqlite3)
target_link_libraries(migraphx PRIVATE PkgConfig::SQLITE3)
find_package(msgpack REQUIRED) find_package(msgpack REQUIRED)
target_link_libraries(migraphx PRIVATE msgpackc-cxx) target_link_libraries(migraphx PRIVATE msgpackc-cxx)
# Make this available to the tests # Make this available to the tests
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
...@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names) ...@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names)
options.output_node_names = std::vector<std::string>(names.begin(), names.end()); options.output_node_names = std::vector<std::string>(names.begin(), names.end());
} }
std::vector<argument>
run_async(program& p, const parameter_map& params, void* s, std::string_view name)
{
execution_environment exec_env{any_ptr(s, name), true};
return p.eval(params, exec_env);
}
template <class Value> template <class Value>
std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m) std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m)
{ {
...@@ -265,11 +273,18 @@ struct experimental_custom_op ...@@ -265,11 +273,18 @@ struct experimental_custom_op
template <class CustomOp> template <class CustomOp>
struct custom_operation struct custom_operation
{ {
template <class Self, class F> template <class Self, class F>
static auto reflect(Self&, F) static auto reflect(Self&, F)
{ {
return pack(); return pack();
} }
value attributes() const
{
return {{"custom_op", true}, {"target", op.runs_on_offload_target() ? "gpu" : "cpu"}};
}
CustomOp op; CustomOp op;
std::string name() const { return op.xobject.name; } std::string name() const { return op.xobject.name; }
...@@ -284,6 +299,23 @@ struct custom_operation ...@@ -284,6 +299,23 @@ struct custom_operation
{ {
return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs)); return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
} }
std::ptrdiff_t output_alias(std::vector<shape> inputs) const
{
auto alias_vec = op.output_alias(std::move(inputs));
// TODO: For now, only support one output alias
if(alias_vec.empty())
{
return -1;
}
if(alias_vec.size() > 1)
{
MIGRAPHX_THROW("Currently, CustomOps in MIGraphX only supports one output_alias");
}
return alias_vec.front();
}
bool runs_on_offload_target() const { return op.runs_on_offload_target(); }
}; };
template <class CustomOp> template <class CustomOp>
...@@ -613,9 +645,9 @@ struct migraphx_experimental_custom_op ...@@ -613,9 +645,9 @@ struct migraphx_experimental_custom_op
migraphx::shape output, migraphx::shape output,
std::vector<migraphx::argument> inputs) const std::vector<migraphx::argument> inputs) const
{ {
std::remove_pointer_t<migraphx_argument_t> out;
if(compute_f == nullptr) if(compute_f == nullptr)
throw std::runtime_error("compute function is missing."); throw std::runtime_error("compute function is missing.");
std::remove_pointer_t<migraphx_argument_t> out;
std::array<char, 256> exception_msg; std::array<char, 256> exception_msg;
exception_msg.front() = '\0'; exception_msg.front() = '\0';
auto api_error_result = compute_f(&out, auto api_error_result = compute_f(&out,
...@@ -637,9 +669,9 @@ struct migraphx_experimental_custom_op ...@@ -637,9 +669,9 @@ struct migraphx_experimental_custom_op
migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr; migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{ {
std::remove_pointer_t<migraphx_shape_t> out;
if(compute_shape_f == nullptr) if(compute_shape_f == nullptr)
throw std::runtime_error("compute_shape function is missing."); throw std::runtime_error("compute_shape function is missing.");
std::remove_pointer_t<migraphx_shape_t> out;
std::array<char, 256> exception_msg; std::array<char, 256> exception_msg;
exception_msg.front() = '\0'; exception_msg.front() = '\0';
auto api_error_result = compute_shape_f(&out, auto api_error_result = compute_shape_f(&out,
...@@ -655,6 +687,49 @@ struct migraphx_experimental_custom_op ...@@ -655,6 +687,49 @@ struct migraphx_experimental_custom_op
} }
return (&out)->object; return (&out)->object;
} }
migraphx_experimental_custom_op_output_alias output_alias_f = nullptr;
std::vector<size_t> output_alias(std::vector<migraphx::shape> inputs) const
{
if(output_alias_f == nullptr)
throw std::runtime_error("output_alias function is missing.");
std::array<size_t, 1024> out;
std::remove_pointer_t<size_t*> out_size = 1024;
std::array<char, 256> exception_msg;
exception_msg.front() = '\0';
auto api_error_result = output_alias_f(out.data(),
&out_size,
object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_shapes_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
{
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in output_alias of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return {out.begin(), out.begin() + out_size}; // cppcheck-suppress returnDanglingLifetime;
}
migraphx_experimental_custom_op_runs_on_offload_target runs_on_offload_target_f = nullptr;
bool runs_on_offload_target() const
{
if(runs_on_offload_target_f == nullptr)
throw std::runtime_error("runs_on_offload_target function is missing.");
std::remove_pointer_t<bool*> out;
std::array<char, 256> exception_msg;
exception_msg.front() = '\0';
auto api_error_result = runs_on_offload_target_f(
&out, object_ptr.data, exception_msg.data(), exception_msg.size());
if(api_error_result != migraphx_status_success)
{
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in runs_on_offload_target of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return out;
}
}; };
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
...@@ -758,6 +833,16 @@ extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, ...@@ -758,6 +833,16 @@ extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).elements();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -791,6 +876,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha ...@@ -791,6 +876,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).index((i));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument) extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{ {
auto api_error_result = migraphx::try_([&] { destroy((argument)); }); auto api_error_result = migraphx::try_([&] { destroy((argument)); });
...@@ -1348,6 +1443,23 @@ extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out, ...@@ -1348,6 +1443,23 @@ extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_program_run_async(migraphx_arguments_t* out,
migraphx_program_t program,
migraphx_program_parameters_t params,
void* s,
const char* name)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(params == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer");
*out = allocate<migraphx_arguments_t>(
migraphx::run_async((program->object), (params->object), (s), (name)));
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x) migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x)
{ {
...@@ -1879,6 +1991,22 @@ extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape( ...@@ -1879,6 +1991,22 @@ extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_experimental_custom_op_set_output_alias(migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_output_alias input)
{
auto api_error_result = migraphx::try_([&] { (obj)->output_alias_f = (input); });
return api_error_result;
}
extern "C" migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_runs_on_offload_target input)
{
auto api_error_result = migraphx::try_([&] { (obj)->runs_on_offload_target_f = (input); });
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op) migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op)
{ {
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h> #include <stdbool.h>
// Add new types here // Add new types here
// clang-format off // clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \ #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
...@@ -144,6 +143,16 @@ typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraph ...@@ -144,6 +143,16 @@ typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraph
size_t exception_msg_size, size_t exception_msg_size,
migraphx_shapes_t inputs); migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_output_alias)(size_t* out,
size_t* out_size,
void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_runs_on_offload_target)(
bool* out, void* obj, char* exception_msg, size_t exception_msg_size);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input); typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input); typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
...@@ -175,6 +184,8 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap ...@@ -175,6 +184,8 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape); migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape); migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_status
...@@ -182,6 +193,8 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha ...@@ -182,6 +193,8 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape); migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
...@@ -344,6 +357,12 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out, ...@@ -344,6 +357,12 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params); migraphx_program_parameters_t params);
migraphx_status migraphx_program_run_async(migraphx_arguments_t* out,
migraphx_program_t program,
migraphx_program_parameters_t params,
void* s,
const char* name);
migraphx_status migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x); migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x);
...@@ -502,6 +521,13 @@ migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t ob ...@@ -502,6 +521,13 @@ migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t ob
migraphx_status migraphx_experimental_custom_op_set_compute_shape( migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input); migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
migraphx_status migraphx_experimental_custom_op_set_output_alias(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input);
migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_runs_on_offload_target input);
migraphx_status migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op); migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
......
...@@ -25,10 +25,12 @@ ...@@ -25,10 +25,12 @@
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP #define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h" #include "migraphx.h"
#include <algorithm>
#include <cstring> #include <cstring>
#include <initializer_list> #include <initializer_list>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <memory> #include <memory>
#include <numeric>
#include <exception> #include <exception>
#include <vector> #include <vector>
#include <cassert> #include <cassert>
...@@ -340,6 +342,11 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -340,6 +342,11 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
this->set_handle(p, std::move(lifetime)); \ this->set_handle(p, std::move(lifetime)); \
} }
template <size_t N>
struct out_params
{
};
template <class Base> template <class Base>
struct interface_base : Base struct interface_base : Base
{ {
...@@ -391,7 +398,22 @@ struct interface_base : Base ...@@ -391,7 +398,22 @@ struct interface_base : Base
} }
template <class T, class Setter, class F> template <class T, class Setter, class F>
void set_fp(Setter setter, F pf) void set_fp(Setter setter, F pf, out_params<2>)
{
static F f = pf;
(void)f; // avoid warning on gcc
call(setter,
this->get_handle_ptr(),
[](auto out1, auto out2, void* obj, char* ex_msg, size_t ex_msg_size, auto... xs)
-> migraphx_status {
return try_([&] { call_cast_arg<T>(rank<2>{}, f, out1, out2, obj, xs...); },
ex_msg,
ex_msg_size);
});
}
template <class T, class Setter, class F>
void set_fp(Setter setter, F pf, out_params<1>)
{ {
static F f = pf; static F f = pf;
(void)f; // avoid warning on gcc (void)f; // avoid warning on gcc
...@@ -405,11 +427,27 @@ struct interface_base : Base ...@@ -405,11 +427,27 @@ struct interface_base : Base
} }
template <class T, class Setter, class F> template <class T, class Setter, class F>
void set_auto_fp(Setter setter, F f) void set_fp(Setter setter, F pf, out_params<0>)
{ {
return set_fp<T>(setter, [=](T& obj, auto out, auto... xs) { static F f = pf;
auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...); (void)f; // avoid warning on gcc
}); call(setter,
this->get_handle_ptr(),
[](void* obj, char* ex_msg, size_t ex_msg_size, auto... xs) -> migraphx_status {
return try_(
[&] { call_cast_arg<T>(rank<0>{}, f, obj, xs...); }, ex_msg, ex_msg_size);
});
}
template <class T, class Setter, class F, class Out>
void set_auto_fp(Setter setter, F f, Out nums)
{
return set_fp<T>(
setter,
[=](T& obj, auto out1, auto out2, auto... xs) {
auto_invoke(f, out1, out2, obj, auto_convert_param(rank<2>{}, xs)...);
},
nums);
} }
struct no_out_arg struct no_out_arg
...@@ -419,7 +457,7 @@ struct interface_base : Base ...@@ -419,7 +457,7 @@ struct interface_base : Base
template <class T, class F, class X, class... Xs, class = std::enable_if_t<std::is_void<X>{}>> template <class T, class F, class X, class... Xs, class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs) static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs)
{ {
f(reinterpret_cast<T*>(obj), no_out_arg{}, xs...); f(reinterpret_cast<T*>(obj), no_out_arg{}, no_out_arg{}, xs...);
} }
template <class T, template <class T,
...@@ -430,17 +468,35 @@ struct interface_base : Base ...@@ -430,17 +468,35 @@ struct interface_base : Base
class = std::enable_if_t<std::is_void<X>{}>> class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs) static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs)
{ {
f(*reinterpret_cast<T*>(obj), result, xs...); f(*reinterpret_cast<T*>(obj), result, no_out_arg{}, xs...);
}
template <class T,
class F,
class R1,
class R2,
class X,
class... Xs,
class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<2>, F f, R1 result1, R2 result2, X* obj, Xs... xs)
{
f(*reinterpret_cast<T*>(obj), result1, result2, xs...);
}
template <class F, class T1, class T2, class... Ts>
void auto_invoke(F f, T1* out1, T2* out2, Ts&&... xs)
{
auto_assign(rank<2>{}, out1, out2, f(std::forward<Ts>(xs)...));
} }
template <class F, class T, class... Ts> template <class F, class T, class... Ts>
void auto_invoke(F f, T* out, Ts&&... xs) void auto_invoke(F f, T* out, no_out_arg, Ts&&... xs)
{ {
auto_assign(rank<2>{}, out, f(std::forward<Ts>(xs)...)); auto_assign(rank<1>{}, out, f(std::forward<Ts>(xs)...));
} }
template <class F, class T, class... Ts> template <class F, class T, class... Ts>
void auto_invoke(F f, no_out_arg, Ts&&... xs) void auto_invoke(F f, no_out_arg, no_out_arg, Ts&&... xs)
{ {
f(std::forward<Ts>(xs)...); f(std::forward<Ts>(xs)...);
} }
...@@ -469,7 +525,7 @@ struct interface_base : Base ...@@ -469,7 +525,7 @@ struct interface_base : Base
template <class T, class U> template <class T, class U>
void auto_assign(rank<0>, T* out, U x) void auto_assign(rank<0>, T* out, U x)
{ {
return *out = x; *out = x;
} }
template <class T, class U> template <class T, class U>
...@@ -477,12 +533,21 @@ struct interface_base : Base ...@@ -477,12 +533,21 @@ struct interface_base : Base
{ {
x.assign_to_handle(out); x.assign_to_handle(out);
} }
template <class T1, class T2, class U, class = std::enable_if_t<std::is_same<T2, size_t>{}>>
auto auto_assign(rank<2>, T1* out_ptr, T2* out_size, U x)
{
*out_size = std::min(*out_size, x.size());
std::copy_n(x.begin(), *out_size, out_ptr);
}
}; };
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \ #define MIGRAPHX_INTERFACE_LIFT(n_out, T, prefix, name) \
this->set_auto_fp<T>(&migraphx_##prefix##_set_##name, \ this->set_auto_fp<T>( \
[](T& x, auto... xs) { return x.name(xs...); }) &migraphx_##prefix##_set_##name, \
[](T& x, auto... xs) { return x.name(xs...); }, \
out_params<n_out>{})
template <class Base, class T> template <class Base, class T>
using require_interface = using require_interface =
...@@ -517,7 +582,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -517,7 +582,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); } shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(shape); MIGRAPHX_HANDLE_CONSTRUCTOR(shape)
/// Construct a scalar shape /// Construct a scalar shape
shape(migraphx_shape_datatype_t type) shape(migraphx_shape_datatype_t type)
...@@ -567,6 +632,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -567,6 +632,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout; return pout;
} }
size_t elements() const
{
size_t pout;
call(&migraphx_shape_elements, &pout, this->get_handle_ptr());
return pout;
}
size_t bytes() const size_t bytes() const
{ {
size_t pout; size_t pout;
...@@ -581,6 +653,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -581,6 +653,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return result; return result;
} }
// map element index to space index
size_t index(size_t i) const
{
size_t result;
call(&migraphx_shape_index, &result, this->get_handle_ptr(), i);
return result;
}
friend bool operator==(const shape& px, const shape& py) friend bool operator==(const shape& px, const shape& py)
{ {
bool pout; bool pout;
...@@ -588,7 +668,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -588,7 +668,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout; return pout;
} }
friend bool operator!=(const shape& px, const shape& py) { return !(px == py); } friend bool operator!=(const shape& px, const shape& py) { return not(px == py); }
}; };
/** /**
...@@ -601,7 +681,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -601,7 +681,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
argument() {} argument() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(argument); MIGRAPHX_HANDLE_CONSTRUCTOR(argument)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
...@@ -628,9 +708,15 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -628,9 +708,15 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
template <typename T> template <typename T>
std::vector<T> as_vector() const std::vector<T> as_vector() const
{ {
size_t vector_len = this->get_shape().bytes() / sizeof(T); auto ss = this->get_shape();
T* buffer_ptr = reinterpret_cast<T*>(this->data()); auto num_elements = ss.elements();
return {buffer_ptr, buffer_ptr + vector_len}; std::vector<T> res(num_elements);
T* buffer_ptr = reinterpret_cast<T*>(this->data());
for(size_t i = 0; i < num_elements; i++)
{
res[i] = buffer_ptr[ss.index(i)];
}
return res;
} }
/// Generate an argument using random data /// Generate an argument using random data
...@@ -647,7 +733,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -647,7 +733,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return pout; return pout;
} }
friend bool operator!=(const argument& px, const argument& py) { return !(px == py); } friend bool operator!=(const argument& px, const argument& py) { return not(px == py); }
}; };
/// A target for compilation /// A target for compilation
...@@ -655,7 +741,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target) ...@@ -655,7 +741,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{ {
target() {} target() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(target); MIGRAPHX_HANDLE_CONSTRUCTOR(target)
/// Construct a target from its name /// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); } target(const char* name) { this->make_handle(&migraphx_target_create, name); }
...@@ -665,7 +751,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -665,7 +751,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
program_parameter_shapes() {} program_parameter_shapes() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes); MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes)
size_t size() const size_t size() const
{ {
...@@ -684,7 +770,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -684,7 +770,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
std::vector<const char*> names() const std::vector<const char*> names() const
{ {
std::vector<const char*> result(this->size()); std::vector<const char*> result(this->size());
if(!result.empty()) if(not result.empty())
{ {
call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr()); call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr());
} }
...@@ -695,7 +781,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -695,7 +781,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program /// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters); MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); } program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
...@@ -722,7 +808,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) ...@@ -722,7 +808,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(arguments); MIGRAPHX_HANDLE_CONSTRUCTOR(arguments)
size_t size() const size_t size() const
{ {
...@@ -741,7 +827,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -741,7 +827,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(shapes); MIGRAPHX_HANDLE_CONSTRUCTOR(shapes)
size_t size() const size_t size() const
{ {
...@@ -760,7 +846,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -760,7 +846,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
struct operation : MIGRAPHX_HANDLE_BASE(operation) struct operation : MIGRAPHX_HANDLE_BASE(operation)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(operation); MIGRAPHX_HANDLE_CONSTRUCTOR(operation)
template <class... Ts> template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs) operation(const char* name, const char* attributes = nullptr, Ts... xs)
...@@ -778,12 +864,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -778,12 +864,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction) struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instruction); MIGRAPHX_HANDLE_CONSTRUCTOR(instruction)
}; };
struct instructions : MIGRAPHX_HANDLE_BASE(instructions) struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions); MIGRAPHX_HANDLE_CONSTRUCTOR(instructions)
template <class... Ts> template <class... Ts>
instructions(Ts... xs) instructions(Ts... xs)
...@@ -797,7 +883,7 @@ struct module; ...@@ -797,7 +883,7 @@ struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules) struct modules : MIGRAPHX_HANDLE_BASE(modules)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(modules); MIGRAPHX_HANDLE_CONSTRUCTOR(modules)
template <class... Ts> template <class... Ts>
modules(Ts... xs) modules(Ts... xs)
...@@ -911,7 +997,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) ...@@ -911,7 +997,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{ {
compile_options() { this->make_handle(&migraphx_compile_options_create); } compile_options() { this->make_handle(&migraphx_compile_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options); MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options)
/// For targets with offloaded memory(such as the gpu), this will insert /// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the /// instructions during compilation to copy the input parameters to the
...@@ -935,7 +1021,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -935,7 +1021,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
program() { this->make_handle(&migraphx_program_create); } program() { this->make_handle(&migraphx_program_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(program); MIGRAPHX_HANDLE_CONSTRUCTOR(program)
/// Compile the program for a specific target to be ran on /// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const void compile(const target& ptarget, const compile_options& poptions) const
...@@ -979,6 +1065,20 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -979,6 +1065,20 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return arguments(pout, own{}); return arguments(pout, own{});
} }
template <class Stream>
/// Overloaded to allow for execution_environment input
arguments run_async(const program_parameters& pparams, Stream* s) const
{
migraphx_arguments_t pout;
call(&migraphx_program_run_async,
&pout,
this->get_handle_ptr(),
pparams.get_handle_ptr(),
s,
get_type_name<Stream>().c_str());
return arguments(pout, own{});
}
void print() const { call(&migraphx_program_print, this->get_handle_ptr()); } void print() const { call(&migraphx_program_print, this->get_handle_ptr()); }
program sort() program sort()
...@@ -1015,13 +1115,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -1015,13 +1115,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return module{p_modu, this->share_handle()}; return module{p_modu, this->share_handle()};
} }
friend bool operator!=(const program& px, const program& py) { return !(px == py); } friend bool operator!=(const program& px, const program& py) { return not(px == py); }
}; };
// options for migraphx file format options // options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options) struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options); MIGRAPHX_HANDLE_CONSTRUCTOR(file_options)
file_options() { this->make_handle(&migraphx_file_options_create); } file_options() { this->make_handle(&migraphx_file_options_create); }
// set file format // set file format
...@@ -1063,7 +1163,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -1063,7 +1163,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{ {
onnx_options() { this->make_handle(&migraphx_onnx_options_create); } onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options); MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options)
/// Make onnx parser treat an inputs with a certain dimensions /// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1145,7 +1245,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options) ...@@ -1145,7 +1245,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{ {
tf_options() { this->make_handle(&migraphx_tf_options_create); } tf_options() { this->make_handle(&migraphx_tf_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options); MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options)
/// Make tf parser treat an inputs with a certain dimensions /// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1198,7 +1298,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names) ...@@ -1198,7 +1298,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{ {
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); } quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names); MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names)
void add(const std::string& name) void add(const std::string& name)
{ {
...@@ -1223,7 +1323,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options) ...@@ -1223,7 +1323,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{ {
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); } quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options); MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options)
/// Add an operator that should be quantized /// Add an operator that should be quantized
void add_op_name(const std::string& name) void add_op_name(const std::string& name)
...@@ -1255,7 +1355,10 @@ struct experimental_custom_op_base ...@@ -1255,7 +1355,10 @@ struct experimental_custom_op_base
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual argument compute(context ctx, shape output, arguments inputs) const = 0; virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
virtual shape compute_shape(shapes inputs) const = 0; virtual shape compute_shape(shapes inputs) const = 0;
virtual ~experimental_custom_op_base() = default; virtual std::vector<size_t> output_alias(shapes) const { return {}; }
// TODO: Return target string instead of bool
virtual bool runs_on_offload_target() const = 0;
virtual ~experimental_custom_op_base() = default;
}; };
struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)> struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
...@@ -1267,8 +1370,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental ...@@ -1267,8 +1370,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
obj, obj,
get_type_name(obj).c_str(), get_type_name(obj).c_str(),
obj.name().c_str()); obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape); MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, compute_shape);
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute); MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, compute);
MIGRAPHX_INTERFACE_LIFT(2, T, experimental_custom_op, output_alias);
MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, runs_on_offload_target);
} }
void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); } void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
......
...@@ -115,6 +115,7 @@ def shape(h): ...@@ -115,6 +115,7 @@ def shape(h):
const=True) const=True)
h.method('strides', returns='const std::vector<size_t>&', const=True) h.method('strides', returns='const std::vector<size_t>&', const=True)
h.method('type', returns='migraphx::shape::type_t', const=True) h.method('type', returns='migraphx::shape::type_t', const=True)
h.method('elements', returns='size_t', const=True)
h.method('bytes', returns='size_t', const=True) h.method('bytes', returns='size_t', const=True)
h.method('equal', h.method('equal',
api.params(x='const migraphx::shape&'), api.params(x='const migraphx::shape&'),
...@@ -122,6 +123,7 @@ def shape(h): ...@@ -122,6 +123,7 @@ def shape(h):
returns='bool', returns='bool',
const=True) const=True)
h.method('standard', returns='bool', const=True) h.method('standard', returns='bool', const=True)
h.method('index', api.params(i='size_t'), returns='size_t', const=True)
@auto_handle() @auto_handle()
...@@ -274,6 +276,13 @@ def program(h): ...@@ -274,6 +276,13 @@ def program(h):
params='std::unordered_map<std::string, migraphx::argument>'), params='std::unordered_map<std::string, migraphx::argument>'),
invoke='migraphx::run($@)', invoke='migraphx::run($@)',
returns='std::vector<migraphx::argument>') returns='std::vector<migraphx::argument>')
h.method('run_async',
api.params(
params='std::unordered_map<std::string, migraphx::argument>',
s='void*',
name='const char *'),
invoke='migraphx::run_async($@)',
returns='std::vector<migraphx::argument>')
h.method('equal', h.method('equal',
api.params(x='const migraphx::program&'), api.params(x='const migraphx::program&'),
invoke='migraphx::equal($@)', invoke='migraphx::equal($@)',
...@@ -450,4 +459,8 @@ def experimental_custom_op(h): ...@@ -450,4 +459,8 @@ def experimental_custom_op(h):
h.virtual('compute_shape', h.virtual('compute_shape',
api.params(inputs='std::vector<migraphx::shape>'), api.params(inputs='std::vector<migraphx::shape>'),
returns='migraphx::shape') returns='migraphx::shape')
h.virtual('output_alias',
api.params(inputs='std::vector<migraphx::shape>'),
returns='std::vector<size_t>')
h.virtual('runs_on_offload_target', returns='bool')
h.method('register', invoke='migraphx::register_custom_op($@)') h.method('register', invoke='migraphx::register_custom_op($@)')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment