Commit 76547728 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into layernorm_half2

parents 48b39e06 0295965d
...@@ -148,13 +148,15 @@ jobs: ...@@ -148,13 +148,15 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v2 uses: actions/setup-python@v2
with: with:
python-version: 3.6 python-version: 3.8
- name: Install pyflakes - name: Install pyflakes
run: pip install pyflakes==2.3.1 mypy==0.931 run: pip install pyflakes==2.4.0 mypy==0.931
- name: Run pyflakes - name: Run pyflakes
run: | run: |
pyflakes --version
pyflakes examples/ tools/ src/ test/ doc/ pyflakes examples/ tools/ src/ test/ doc/
mypy --version
mypy tools/api.py mypy tools/api.py
linux: linux:
......
...@@ -20,7 +20,7 @@ def rocmtestnode(Map conf) { ...@@ -20,7 +20,7 @@ def rocmtestnode(Map conf) {
rm -rf build rm -rf build
mkdir build mkdir build
cd build cd build
CXX=${compiler} CXXFLAGS='-Werror' cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} .. CXX=${compiler} CXXFLAGS='-Werror' cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} ..
make -j\$(nproc) generate all doc package check VERBOSE=1 make -j\$(nproc) generate all doc package check VERBOSE=1
""" """
echo cmd echo cmd
...@@ -112,17 +112,17 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build -> ...@@ -112,17 +112,17 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
def debug_flags = "-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" def debug_flags = "-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
} }
}, clang_release_navi: rocmnode('navi21') { cmake_build -> }//, clang_release_navi: rocmnode('navi21') { cmake_build ->
stage('HIP Clang Release Navi') { // stage('HIP Clang Release Navi') {
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=release") // cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=release")
} // }
} //}
def onnxnode(name, body) { def onnxnode(name, body) {
return { label -> return { label ->
rocmtestnode(variant: label, node: rocmnodename(name), docker_args: '-u root', body: body, pre: { rocmtestnode(variant: label, node: rocmnodename(name), docker_args: '-u root', body: body, pre: {
sh 'rm -rf ./build/*.deb' sh 'rm -rf ./build/*.deb'
unstash 'migraphx-package' unstash 'migraphx-package'
}) })
} }
} }
......
project(migraphx-doc) project(migraphx-doc)
find_package(ROCM REQUIRED) find_package(ROCM REQUIRED)
include(ROCMDoxygenDoc) include(ROCMDoxygenDoc)
set(DOXYGEN_OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/doxygen/) set(DOXYGEN_OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/doxygen)
rocm_add_doxygen_doc( rocm_add_doxygen_doc(
OUTPUT_DIRECTORY ${DOXYGEN_OUTPUT} OUTPUT_DIRECTORY ${DOXYGEN_OUTPUT}
INPUT INPUT
${PROJECT_SOURCE_DIR}/src ${CMAKE_SOURCE_DIR}/src
INCLUDE_PATH INCLUDE_PATH
${PROJECT_SOURCE_DIR}/src/include ${CMAKE_SOURCE_DIR}/src/include
${PROJECT_SOURCE_DIR}/src/targets/cpu/include ${CMAKE_SOURCE_DIR}/src/targets/cpu/include
${PROJECT_SOURCE_DIR}/src/targets/gpu/include ${CMAKE_SOURCE_DIR}/src/targets/gpu/include
STRIP_FROM_INC_PATH STRIP_FROM_INC_PATH
${PROJECT_SOURCE_DIR}/src/include ${CMAKE_SOURCE_DIR}/src/include
${PROJECT_SOURCE_DIR}/src/targets/cpu/include ${CMAKE_SOURCE_DIR}/src/targets/cpu/include
${PROJECT_SOURCE_DIR}/src/targets/gpu/include ${CMAKE_SOURCE_DIR}/src/targets/gpu/include
EXCLUDE_PATTERNS EXCLUDE_PATTERNS
${PROJECT_SOURCE_DIR}/src/targets/gpu/kernels ${CMAKE_SOURCE_DIR}/src/targets/gpu/kernels
${PROJECT_SOURCE_DIR}/src/targets/gpu/device ${CMAKE_SOURCE_DIR}/src/targets/gpu/device
SEARCH_INCLUDES YES SEARCH_INCLUDES YES
MACRO_EXPANSION YES MACRO_EXPANSION YES
RECURSIVE YES RECURSIVE YES
...@@ -39,13 +38,14 @@ rocm_add_doxygen_doc( ...@@ -39,13 +38,14 @@ rocm_add_doxygen_doc(
EXTRACT_ALL YES EXTRACT_ALL YES
ENUM_VALUES_PER_LINE 1 ENUM_VALUES_PER_LINE 1
FULL_PATH_NAMES YES FULL_PATH_NAMES YES
WARN_LOGFILE "${DOXYGEN_OUTPUT}/DoxygenWarningLog.txt"
PREDEFINED DOXYGEN PREDEFINED DOXYGEN
) )
include(ROCMSphinxDoc) include(ROCMSphinxDoc)
rocm_add_sphinx_doc(src rocm_add_sphinx_doc(src
BUILDER html BUILDER html
OUTPUT_DIR html OUTPUT_DIR html
VARS VARS
breathe_projects.proj=${DOXYGEN_OUTPUT}/xml breathe_projects.proj=${DOXYGEN_OUTPUT}/xml
breathe_default_project=proj breathe_default_project=proj
...@@ -63,6 +63,6 @@ if(LATEX_FOUND) ...@@ -63,6 +63,6 @@ if(LATEX_FOUND)
DEPENDS doxygen DEPENDS doxygen
) )
else() else()
message("Latex builder not found. Latex builder is required only for building the PDF documentation for MIGraph and is not necessary for building the library, or any other components. To build PDF documentation run make in ${CMAKE_CURRENT_SOURCE_DIR}/pdf, once a latex builder is installed.") message("Latex builder not found. Latex builder is required only for building the PDF documentation for MIGraphX and is not necessary for building the library, or any other components. To build PDF documentation run make in ${CMAKE_CURRENT_SOURCE_DIR}/pdf, once a latex builder is installed.")
endif() endif()
Developer Guide Contributor Guide
=============== ===============
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2
:caption: Contents: :caption: Contents:
overview dev_intro
dev/data dev/data
dev/operators dev/operators
dev/program dev/program
......
MIGraphX Fundamentals
======================
MIGraphX provides an optimized execution engine for deep learning neural networks.
We will cover some simple operations in the MIGraphX framework here.
For a quick start guide to using MIGraphX, look in the examples directory: ``https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/tree/develop/examples/migraphx``.
Location of the Examples
-------------------------
The ``ref_dev_examples.cpp`` can be found in the test directory (``/test``).
The executable file ``test_ref_dev_examples`` based on this file will be created in the ``bin/`` of the build directory after running ``make -j$(nproc) test_ref_dev_examples``.
The executable will also be created when running ``make -j$(nproc) check``, alongside all the other tests.
Directions for building MIGraphX from source can be found in the main README file: ``https://github.com/ROCmSoftwarePlatform/AMDMIGraphX#readme``.
Adding Two Literals
--------------------
A program is a collection of modules, which are collections of instructions to be executed when calling `eval <migraphx::program::eval>`.
Each instruction has an associated `operation <migraphx::operation>` which represents the computation to be performed by the instruction.
We start with a snippet of the simple ``add_two_literals()`` function::
// create the program and get a pointer to the main module
migraphx::program p;
auto* mm = p.get_main_module();
// add two literals to the program
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
// make the add operation between the two literals and add it to the program
mm->add_instruction(migraphx::make_op("add"), one, two);
// compile the program on the reference device
p.compile(migraphx::ref::target{});
// evaluate the program and retrieve the result
auto result = p.eval({}).back();
std::cout << "add_two_literals: 1 + 2 = " << result << "\n";
We start by creating a simple ``migraphx::program`` object and then getting a pointer to the main module of it.
The program is a collection of ``modules`` that start executing from the main module, so instructions are added to the modules rather than directly onto the program object.
We then use the `add_literal <migraphx::program::add_literal>` function to add an instruction that stores the literal number ``1`` while returning an `instruction_ref <migraphx::instruction_ref>`.
The returned `instruction_ref <migraphx::instruction_ref>` can be used in another instruction as an input.
We use the same `add_literal <migraphx::program::add_literal>` function to add a ``2`` to the program.
After creating the literals, we then create the instruction to add the numbers together.
This is done by using the `add_instruction <migraphx::program::add_instruction>` function with the ``"add"`` `operation <migraphx::program::operation>` created by `make_op <migraphx::program::make_op>` along with the previous `add_literal` `instruction_ref <migraphx::instruction_ref>` for the input arguments of the instruction.
Finally, we can run this `program <migraphx::program>` by compiling it for the reference target (CPU) and then running it with `eval <migraphx::program::eval>`
The result is then retrieved and printed to the console.
We can compile the program for the GPU as well, but the file will have to be moved to the ``test/gpu/`` directory and the correct target must be included::
#include <migraphx/gpu/target.hpp>
Using Parameters
-----------------
The previous program will always produce the same value of adding ``1`` and ``2``.
In the next program we want to pass an input to a program and compute a value based on the input.
We can modify the program to take an input parameter ``x``, as seen in the ``add_parameter()`` function::
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {1}};
// add a "x" parameter with the shape s
auto x = mm->add_parameter("x", s);
auto two = mm->add_literal(2);
// add the "add" instruction between the "x" parameter and "two" to the module
mm->add_instruction(migraphx::make_op("add"), x, two);
p.compile(migraphx::ref::target{});
This adds a parameter of type ``int32``, and compiles it for the CPU.
To run the program, we need to pass the parameter as a ``parameter_map`` when we call `eval <migraphx::program::eval>`.
We create the ``parameter_map`` by setting the ``x`` key to an `argument <migraphx::argument>` object with an ``int`` data type::
// create a parameter_map object for passing a value to the "x" parameter
std::vector<int> data = {4};
migraphx::parameter_map params;
params["x"] = migraphx::argument(s, data.data());
auto result = p.eval(params).back();
std::cout << "add_parameters: 4 + 2 = " << result << "\n";
EXPECT(result.at<int>() == 6);
Handling Tensor Data
---------------------
In the previous examples we have only been dealing with scalars, but the `shape <migraphx::shape>` class can describe multi-dimensional tensors.
For example, we can compute a simple convolution::
migraphx::program p;
auto* mm = p.get_main_module();
// create shape objects for the input tensor and weights
migraphx::shape input_shape{migraphx::shape::float_type, {2, 3, 4, 4}};
migraphx::shape weights_shape{migraphx::shape::float_type, {3, 3, 3, 3}};
// create the parameters and add the "convolution" operation to the module
auto input = mm->add_parameter("X", input_shape);
auto weights = mm->add_parameter("W", weights_shape);
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), input, weights);
Here we create two parameters for both the ``input`` and ``weights``.
In the previous examples, we created simple literals, however, most programs will take data from allocated buffers (usually on the GPU).
In this case, we can create `argument <migraphx::argument>` objects directly from the pointers to the buffers::
// Compile the program
p.compile(migraphx::ref::target{});
// Allocated buffers by the user
std::vector<float> a = ...;
std::vector<float> c = ...;
// Solution vector
std::vector<float> sol = ...;
// Create the arguments in a parameter_map
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_shape, a.data());
params["W"] = migraphx::argument(weights_shape, c.data());
// Evaluate and confirm the result
auto result = p.eval(params).back();
std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
An `argument <migraphx::argument>` can handle memory buffers from either the GPU or the CPU.
By default when running the `program <migraphx::program>`, buffers are allocated on the corresponding target.
When compiling for the CPU, the buffers by default will be allocated on the CPU.
When compiling for the GPU, the buffers by default will be allocated on the GPU.
With the option ``offload_copy=true`` set while compiling for the GPU, the buffers will be located on the CPU.
Importing From ONNX
--------------------
A `program <migraphx::program>` can be built directly from an onnx file using the MIGraphX ONNX parser.
This makes it easier to use neural networks directly from other frameworks.
In this case, there is a ``parse_onnx`` function::
program p = migraphx::parse_onnx("model.onnx");
p.compile(migraphx::gpu::target{});
...@@ -13,7 +13,7 @@ Welcome to AMD MIGraphX's documentation! ...@@ -13,7 +13,7 @@ Welcome to AMD MIGraphX's documentation!
py_user_guide py_user_guide
cpp_user_guide cpp_user_guide
driver driver
developer_guide contributor_guide
Indices and tables Indices and tables
......
Overview
========
MIGraphX provides an optimized execution engine for deep learning neural networks.
Building a program
------------------
A program consists of a set of instructions to be executed when calling `eval <migraphx::program::eval>`. Each instruction has an associated `operation <migraphx::operation>` which represents the computation to be performed by the instruction.
We can start by building a simple program to add two numbers together::
program p;
instruction_ref one = p.add_literal(1);
instruction_ref two = p.add_literal(2);
p.add_instruction(add{}, one, two);
The `add_literal <migraphx::program::add_literal>` function will add an instruction to the program to store a literal number. The `instruction_ref <migraphx::instruction_ref>` is a reference to the instruction in the program, which can be used to compose the output of the instruction with another instruction.
After creating the literals, we then create the instruction to add the numbers together. This is done by using the `add{} <migraphx::add>` operation class along with the `instruction_ref <migraphx::instruction_ref>` for the input arguments of the instruction.
Finally, we can run this `program <migraphx::program>` by compiling it for the cpu and then running it with `eval <migraphx::program::eval>`::
p.compile(cpu::target{});
argument result = p.eval({});
The easiest way to see the result is to print it::
std::cout << result;
Which will print ``3``.
We can also compile the program for the gpu as well.
Adding parameters
-----------------
Of course, this program will always produce the same value which is quite uninteresting. Instead, we want to pass an input to a program and compute a value based on the input. This can be done with a parameter. For example, we can modify the program to take an input ``x``::
program p;
instruction_ref x = p.add_parameter("x", {shape::int64_type});
instruction_ref two = p.add_literal(2);
p.add_instruction(add{}, x, two);
p.compile(cpu::target{});
This adds a parameter of type ``int64``, and compiles it for the ``cpu``. To run the program, we need to pass the parameter to it when we call `eval <migraphx::program::eval>`::
argument result = p.eval({
{"x", literal{1}.get_argument()}
});
std::cout << result;
This will print ``3``.
A parameter is given as an `argument <migraphx::argument>`. In this case, the simplest way of creating an `argument <migraphx::argument>` is from a `literal <migraphx::literal>`.
Tensor data
-----------
In this example we are just creating numbers, but the `shape <migraphx::shape>` class can describe multi-dimensional tensors. For example, we can build a simple network with convolution and relu::
program p;
instruction_ref input = p.add_parameter("x", shape{shape::float_type, {1, 3, 32, 32}});
instruction_ref weights = p.add_parameter("w", shape{shape::float_type, {1, 3, 5, 5}});
instruction_ref conv = p.add_instruction(convolution{}, input, weights);
p.add_instruction(activation{"relu"}, conv);
Here we create two parameters for both the ``input`` and ``weights``. In the previous examples, we just created simple literals, however, most programs will take data from already allocated buffers (usually on the GPU). In this case, we can create `argument <migraphx::argument>` objects directly from the pointers to the buffers::
// Compile the program
p.compile(gpu::target{});
// Allocated buffers by the user
float* input = ...;
float* weights = ...;
// Create the arguments
argument input_arg{shape{shape::float_type, {1, 3, 32, 32}}, input};
argument weights_arg{shape{shape::float_type, {1, 3, 32, 32}}, weights};
p.eval({{"x", input_arg}, {"w", weights_arg}})
An `argument <migraphx::argument>` can handle memory buffers from either the GPU or the CPU, but when running the `program <migraphx::program>`, buffers should be allocated for the corresponding target. That is, when compiling for the CPU, the buffers should be allocated on the CPU, and when compiling for the GPU the buffers should be allocated on the GPU.
Importing from onnx
-------------------
A `program <migraphx::program>` can be built directly from an onnx file, which makes it easier to use neural networks directly from other frameworks. In this case, there is a ``parse_onnx`` function::
program p = parse_onnx("model.onnx");
p.compile(gpu::target{});
...@@ -121,7 +121,7 @@ target ...@@ -121,7 +121,7 @@ target
Constructs the target. Constructs the target.
:param str name: The name of the target to construct. This can either be 'cpu' or 'gpu'. :param str name: The name of the target to construct. This can either be 'gpu' or 'ref'.
:rtype: target :rtype: target
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
...@@ -212,6 +213,39 @@ void print_program(const program& p) { std::cout << p << std::endl; } ...@@ -212,6 +213,39 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; } void print_module(const module& m) { std::cout << m << std::endl; }
// Plain value type describing an experimental custom operator created
// through the C API. For now it carries only the operator's name.
struct experimental_custom_op
{
    std::string name;

    experimental_custom_op() = default;
    // Implicit conversion from a name is intentional (C API convenience).
    experimental_custom_op(std::string pname) : name{std::move(pname)} {}
};
// Adapter that wraps a C-API custom op handle (CustomOp) so it satisfies
// MIGraphX's operation interface and can be registered with the operation
// registry. Shape computation is delegated to the wrapped op; evaluation is
// intentionally unsupported for experimental custom ops.
template <class CustomOp>
struct custom_operation
{
    // No reflectable state: the user's object lives behind a type-erased
    // pointer inside CustomOp, so reflect() returns an empty pack.
    template <class Self, class F>
    static auto reflect(Self&, F)
    {
        return pack();
    }

    CustomOp op;

    // NOTE(review): assumes CustomOp exposes its C++-side state through an
    // `xobject` member carrying the user-visible name — true for
    // migraphx_experimental_custom_op; confirm for any other instantiation.
    std::string name() const { return op.xobject.name; }

    // Delegates output-shape deduction to the user's compute_shape callback.
    shape compute_shape(std::vector<shape> inputs) const
    {
        return op.compute_shape(std::move(inputs));
    }

    // Experimental custom ops cannot be evaluated; always throws.
    argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
};
// Wraps `op` in a custom_operation adapter and registers it with the global
// operation registry so it can later be constructed by name.
template <class CustomOp>
void register_custom_op(const CustomOp& op)
{
    register_op(custom_operation<CustomOp>{op});
}
migraphx::context get_context(const program& p) { return p.get_context(); } migraphx::context get_context(const program& p) { return p.get_context(); }
} // namespace migraphx } // namespace migraphx
...@@ -238,12 +272,60 @@ void destroy(T* x) ...@@ -238,12 +272,60 @@ void destroy(T* x)
{ {
delete x; // NOLINT delete x; // NOLINT
} }
// TODO: Move to interface preamble
// RAII owner for a type-erased user object passed across the C API boundary.
// C is a copier callback (out, in) that allocates a copy of `in` into `*out`;
// D is a deleter callback invoked on the owned pointer. Both are nullable
// function pointers, so every invocation is guarded against null.
template <class C, class D>
struct manage_generic_ptr
{
    manage_generic_ptr() = default;
    manage_generic_ptr(std::nullptr_t) {}
    // Takes ownership of a fresh copy of pdata made via pcopier.
    manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
        : data(nullptr), copier(pcopier), deleter(pdeleter)
    {
        // Guard the copier (consistent with the copy ctor): with no copier
        // there is nothing to own, so data stays null.
        if(copier)
            copier(&data, pdata);
    }
    manage_generic_ptr(const manage_generic_ptr& rhs)
        : data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
    {
        if(copier)
            copier(&data, rhs.data);
    }
    // Move leaves the source empty (null data and callbacks).
    manage_generic_ptr(manage_generic_ptr&& other) noexcept
        : data(std::exchange(other.data, nullptr)),
          copier(std::exchange(other.copier, nullptr)),
          deleter(std::exchange(other.deleter, nullptr))
    {
    }
    // Copy-and-swap: handles both copy- and move-assignment, and releases
    // the previously owned object when `rhs` is destroyed.
    manage_generic_ptr& operator=(manage_generic_ptr rhs)
    {
        std::swap(data, rhs.data);
        std::swap(copier, rhs.copier);
        std::swap(deleter, rhs.deleter);
        return *this;
    }
    ~manage_generic_ptr()
    {
        // Also guard the deleter: a non-null pointer with a null deleter
        // would otherwise call through a null function pointer (UB).
        if(data != nullptr && deleter != nullptr)
            deleter(data);
    }
    void* data = nullptr;
    C copier   = nullptr;
    D deleter  = nullptr;
};
extern "C" struct migraphx_shape; extern "C" struct migraphx_shape;
struct migraphx_shape struct migraphx_shape
{ {
template <class... Ts> template <class... Ts>
migraphx_shape(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_shape(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::shape object; migraphx::shape object;
...@@ -253,7 +335,8 @@ extern "C" struct migraphx_argument; ...@@ -253,7 +335,8 @@ extern "C" struct migraphx_argument;
struct migraphx_argument struct migraphx_argument
{ {
template <class... Ts> template <class... Ts>
migraphx_argument(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_argument(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::argument object; migraphx::argument object;
...@@ -263,7 +346,8 @@ extern "C" struct migraphx_target; ...@@ -263,7 +346,8 @@ extern "C" struct migraphx_target;
struct migraphx_target struct migraphx_target
{ {
template <class... Ts> template <class... Ts>
migraphx_target(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_target(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::target object; migraphx::target object;
...@@ -273,7 +357,8 @@ extern "C" struct migraphx_program_parameter_shapes; ...@@ -273,7 +357,8 @@ extern "C" struct migraphx_program_parameter_shapes;
struct migraphx_program_parameter_shapes struct migraphx_program_parameter_shapes
{ {
template <class... Ts> template <class... Ts>
migraphx_program_parameter_shapes(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_program_parameter_shapes(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::unordered_map<std::string, migraphx::shape> object; std::unordered_map<std::string, migraphx::shape> object;
...@@ -283,7 +368,8 @@ extern "C" struct migraphx_program_parameters; ...@@ -283,7 +368,8 @@ extern "C" struct migraphx_program_parameters;
struct migraphx_program_parameters struct migraphx_program_parameters
{ {
template <class... Ts> template <class... Ts>
migraphx_program_parameters(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_program_parameters(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::unordered_map<std::string, migraphx::argument> object; std::unordered_map<std::string, migraphx::argument> object;
...@@ -293,7 +379,8 @@ extern "C" struct migraphx_arguments; ...@@ -293,7 +379,8 @@ extern "C" struct migraphx_arguments;
struct migraphx_arguments struct migraphx_arguments
{ {
template <class... Ts> template <class... Ts>
migraphx_arguments(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_arguments(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::vector<migraphx::argument> object; std::vector<migraphx::argument> object;
...@@ -303,7 +390,8 @@ extern "C" struct migraphx_shapes; ...@@ -303,7 +390,8 @@ extern "C" struct migraphx_shapes;
struct migraphx_shapes struct migraphx_shapes
{ {
template <class... Ts> template <class... Ts>
migraphx_shapes(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_shapes(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::vector<migraphx::shape> object; std::vector<migraphx::shape> object;
...@@ -343,7 +431,8 @@ extern "C" struct migraphx_module; ...@@ -343,7 +431,8 @@ extern "C" struct migraphx_module;
struct migraphx_module struct migraphx_module
{ {
template <class... Ts> template <class... Ts>
migraphx_module(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_module(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::module object; migraphx::module object;
...@@ -353,7 +442,8 @@ extern "C" struct migraphx_program; ...@@ -353,7 +442,8 @@ extern "C" struct migraphx_program;
struct migraphx_program struct migraphx_program
{ {
template <class... Ts> template <class... Ts>
migraphx_program(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_program(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::program object; migraphx::program object;
...@@ -363,7 +453,8 @@ extern "C" struct migraphx_operation; ...@@ -363,7 +453,8 @@ extern "C" struct migraphx_operation;
struct migraphx_operation struct migraphx_operation
{ {
template <class... Ts> template <class... Ts>
migraphx_operation(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_operation(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::operation object; migraphx::operation object;
...@@ -373,7 +464,8 @@ extern "C" struct migraphx_onnx_options; ...@@ -373,7 +464,8 @@ extern "C" struct migraphx_onnx_options;
struct migraphx_onnx_options struct migraphx_onnx_options
{ {
template <class... Ts> template <class... Ts>
migraphx_onnx_options(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_onnx_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::onnx_options object; migraphx::onnx_options object;
...@@ -383,7 +475,8 @@ extern "C" struct migraphx_file_options; ...@@ -383,7 +475,8 @@ extern "C" struct migraphx_file_options;
struct migraphx_file_options struct migraphx_file_options
{ {
template <class... Ts> template <class... Ts>
migraphx_file_options(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_file_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::file_options object; migraphx::file_options object;
...@@ -393,7 +486,8 @@ extern "C" struct migraphx_compile_options; ...@@ -393,7 +486,8 @@ extern "C" struct migraphx_compile_options;
struct migraphx_compile_options struct migraphx_compile_options
{ {
template <class... Ts> template <class... Ts>
migraphx_compile_options(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_compile_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::compile_options object; migraphx::compile_options object;
...@@ -403,7 +497,8 @@ extern "C" struct migraphx_tf_options; ...@@ -403,7 +497,8 @@ extern "C" struct migraphx_tf_options;
struct migraphx_tf_options struct migraphx_tf_options
{ {
template <class... Ts> template <class... Ts>
migraphx_tf_options(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_tf_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::tf_options object; migraphx::tf_options object;
...@@ -413,7 +508,8 @@ extern "C" struct migraphx_quantize_op_names; ...@@ -413,7 +508,8 @@ extern "C" struct migraphx_quantize_op_names;
struct migraphx_quantize_op_names struct migraphx_quantize_op_names
{ {
template <class... Ts> template <class... Ts>
migraphx_quantize_op_names(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_quantize_op_names(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::vector<std::string> object; std::vector<std::string> object;
...@@ -423,7 +519,8 @@ extern "C" struct migraphx_quantize_int8_options; ...@@ -423,7 +519,8 @@ extern "C" struct migraphx_quantize_int8_options;
struct migraphx_quantize_int8_options struct migraphx_quantize_int8_options
{ {
template <class... Ts> template <class... Ts>
migraphx_quantize_int8_options(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_quantize_int8_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::quantize_int8_options object; migraphx::quantize_int8_options object;
...@@ -433,12 +530,41 @@ extern "C" struct migraphx_context; ...@@ -433,12 +530,41 @@ extern "C" struct migraphx_context;
struct migraphx_context struct migraphx_context
{ {
template <class... Ts> template <class... Ts>
migraphx_context(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_context(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::context object; migraphx::context object;
}; };
extern "C" struct migraphx_experimental_custom_op;
// Backing object for the C handle migraphx_experimental_custom_op_t.
// Owns a type-erased copy of the user's op object (managed through the
// user-supplied copy/delete callbacks) plus the C++-side op state.
struct migraphx_experimental_custom_op
{
    template <class... Ts>
    migraphx_experimental_custom_op(void* p,
                                    migraphx_experimental_custom_op_copy c,
                                    migraphx_experimental_custom_op_delete d,
                                    Ts&&... xs)
        : object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
    {
    }

    // Owning pointer to the user's opaque object; copied/freed via the
    // C callbacks given at construction.
    manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
        object_ptr = nullptr;
    // C++-side state (currently just the op's name).
    migraphx::experimental_custom_op xobject;

    // Optional user callback that computes the output shape; installed via
    // migraphx_experimental_custom_op_set_compute_shape.
    migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;

    // Invokes the user's compute_shape callback, converting between the C++
    // and C API types. Throws std::runtime_error if the callback was never
    // set or if it reports a non-success status.
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        std::remove_pointer_t<migraphx_shape_t> out;
        if(compute_shape_f == nullptr)
            throw std::runtime_error("compute_shape function is missing.");
        auto api_error_result =
            compute_shape_f(&out, object_ptr.data, object_cast<migraphx_shapes_t>(&(inputs)));
        if(api_error_result != migraphx_status_success)
            throw std::runtime_error("Error in compute_shape.");
        return (&out)->object;
    }
};
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{ {
auto api_error_result = migraphx::try_([&] { destroy((shape)); }); auto api_error_result = migraphx::try_([&] { destroy((shape)); });
...@@ -1564,3 +1690,51 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont ...@@ -1564,3 +1690,51 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont
}); });
return api_error_result; return api_error_result;
} }
// C API: destroys a handle previously returned by
// migraphx_experimental_custom_op_create. Exceptions are converted to a
// migraphx_status by migraphx::try_.
extern "C" migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op)
{
    auto api_error_result = migraphx::try_([&] { destroy((experimental_custom_op)); });
    return api_error_result;
}
// C API: copies the contents of `input` into `output`. The user object is
// deep-copied through its copy callback (see manage_generic_ptr's copy
// assignment). Exceptions are converted to a migraphx_status.
extern "C" migraphx_status
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output,
                                          const_migraphx_experimental_custom_op_t input)
{
    auto api_error_result = migraphx::try_([&] { *output = *input; });
    return api_error_result;
}
// C API: allocates a new experimental custom op handle.
//   obj  - opaque pointer to the user's op object
//   c, d - copy and delete callbacks used to manage `obj`'s lifetime
//   name - user-visible operator name
// On success *experimental_custom_op points at the new handle; the caller
// releases it with migraphx_experimental_custom_op_destroy.
extern "C" migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
                                       void* obj,
                                       migraphx_experimental_custom_op_copy c,
                                       migraphx_experimental_custom_op_delete d,
                                       const char* name)
{
    auto api_error_result = migraphx::try_([&] {
        *experimental_custom_op =
            allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (name));
    });
    return api_error_result;
}
// C API: installs the user's compute_shape callback on the handle. The
// callback is invoked later when MIGraphX needs the op's output shape.
extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
    migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input)
{
    auto api_error_result = migraphx::try_([&] { (obj)->compute_shape_f = (input); });
    return api_error_result;
}
// C API: registers the custom op with MIGraphX's operation registry so it
// can subsequently be constructed by name. Returns
// migraphx_status_bad_param if the handle is null.
extern "C" migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op)
{
    auto api_error_result = migraphx::try_([&] {
        if(experimental_custom_op == nullptr)
            MIGRAPHX_THROW(migraphx_status_bad_param,
                           "Bad parameter experimental_custom_op: Null pointer");
        migraphx::register_custom_op((*experimental_custom_op));
    });
    return api_error_result;
}
...@@ -103,6 +103,17 @@ typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int ...@@ -103,6 +103,17 @@ typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int
typedef struct migraphx_context* migraphx_context_t; typedef struct migraphx_context* migraphx_context_t;
typedef const struct migraphx_context* const_migraphx_context_t; typedef const struct migraphx_context* const_migraphx_context_t;
// Opaque handle types for the experimental custom-op C API.
typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;

// Callback computing a custom op's output shape (`out`) from its input
// shapes; `obj` is the opaque user object supplied at creation.
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
                                                                         void* obj,
                                                                         migraphx_shapes_t inputs);
// Callback that copies the user object backing a custom op into `*out`.
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
// Callback that deletes the user object backing a custom op.
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input); migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
...@@ -422,6 +433,26 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog, ...@@ -422,6 +433,26 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_status migraphx_context_finish(const_migraphx_context_t context); migraphx_status migraphx_context_finish(const_migraphx_context_t context);
// Destroy a custom op handle and free its resources.
migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);

// Copy-assign the contents of `input` into `output`.
migraphx_status
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output,
                                          const_migraphx_experimental_custom_op_t input);

// Create a custom op handle from a user object plus its copy/delete
// callbacks and a display name.
migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
                                       void* obj,
                                       migraphx_experimental_custom_op_copy c,
                                       migraphx_experimental_custom_op_delete d,
                                       const char* name);

// Install the shape-computation callback on the handle.
migraphx_status migraphx_experimental_custom_op_set_compute_shape(
    migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);

// Register the custom op with MIGraphX; fails with bad_param on null.
migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
......
...@@ -15,6 +15,16 @@ namespace migraphx { ...@@ -15,6 +15,16 @@ namespace migraphx {
inline namespace api { // NOLINT inline namespace api { // NOLINT
#endif #endif
// Tag type for ranked overload resolution: rank<N> derives from rank<N - 1>,
// so passing rank<K>{} selects the viable overload with the highest rank
// parameter <= K, falling back down the inheritance chain otherwise.
template <int N>
struct rank : rank<N - 1>
{
};
// Base case terminating the rank<N> inheritance chain.
template <>
struct rank<0>
{
};
template <class T, class F, class... Ts> template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs) T* make(F f, Ts&&... xs)
{ {
...@@ -231,6 +241,138 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -231,6 +241,138 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std::shared_ptr<T> m_handle; std::shared_ptr<T> m_handle;
}; };
// Mixin that adapts a C++ object implementing an interface (e.g. a custom op)
// to the C API's function-pointer callbacks. It builds the copy/delete
// callbacks for the wrapped object, installs member functions as C callbacks
// (set_fp/set_auto_fp), and converts between C handle types and their C++
// wrappers on the way in and out.
template <class Base>
struct interface_base : Base
{
    interface_base() : Base() {}

    protected:
    // Run f, translating any escaping exception into a C status code.
    template <class F>
    static migraphx_status try_(F f) // NOLINT
    {
        try
        {
            f();
            return migraphx_status_success;
        }
        catch(...)
        {
            return migraphx_status_unknown_error;
        }
    }

    // Create the underlying handle for `obj`, synthesizing type-correct
    // copy and delete callbacks for T; extra args are forwarded to the
    // C create function `f`.
    template <class F, class T, class... Ts>
    void make_interface(F f, T& obj, Ts&&... xs)
    {
        auto copy = [](void** out, void* input) {
            return try_([&] {
                T** y = reinterpret_cast<T**>(out);
                T* x = reinterpret_cast<T*>(input);
                assert(x != nullptr and y != nullptr and *y == nullptr);
                *y = new T(*x); // NOLINT
            });
        };
        auto del = [](void* input) {
            return try_([&] {
                T* x = reinterpret_cast<T*>(input);
                delete x; // NOLINT
            });
        };
        this->make_handle(f, &obj, copy, del, std::forward<Ts>(xs)...);
    }

    // Store `pf` in a function-local static (one per <T, Setter, F>
    // instantiation) so a captureless lambda — convertible to a C function
    // pointer — can reach it, then install that lambda via `setter`.
    template <class T, class Setter, class F>
    void set_fp(Setter setter, F pf)
    {
        static F f = pf;
        (void)f; // avoid warning on gcc
        call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status {
            return try_([&] { call_cast_arg<T>(rank<1>{}, f, xs...); });
        });
    }

    // Like set_fp, but automatically converts C handle parameters to their
    // C++ wrappers and assigns the C++ result through the C out-parameter.
    template <class T, class Setter, class F>
    void set_auto_fp(Setter setter, F f)
    {
        return set_fp<T>(setter, [=](T& obj, auto out, auto... xs) {
            auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...);
        });
    }

    // Placeholder passed instead of an out-parameter when the C callback
    // has no result slot (void-returning interface method).
    struct no_out_arg
    {
    };

    // Callback shape: (void* obj, args...) — no out-parameter.
    template <class T, class F, class X, class... Xs, class = std::enable_if_t<std::is_void<X>{}>>
    static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs)
    {
        f(reinterpret_cast<T*>(obj), no_out_arg{}, xs...);
    }

    // Callback shape: (result, void* obj, args...) — preferred (rank<1>)
    // when the first parameter is the result slot.
    template <class T,
              class F,
              class R,
              class X,
              class... Xs,
              class = std::enable_if_t<std::is_void<X>{}>>
    static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs)
    {
        f(*reinterpret_cast<T*>(obj), result, xs...);
    }

    // Invoke f and write its result through the out-pointer.
    template <class F, class T, class... Ts>
    void auto_invoke(F f, T* out, Ts&&... xs)
    {
        auto_assign(rank<2>{}, out, f(std::forward<Ts>(xs)...));
    }

    // Invoke f for void-returning methods (no out-parameter to fill).
    template <class F, class T, class... Ts>
    void auto_invoke(F f, no_out_arg, Ts&&... xs)
    {
        f(std::forward<Ts>(xs)...);
    }

    // Fundamental/enum parameters pass through unchanged.
    template <class T, class = std::enable_if_t<std::is_fundamental<T>{} or std::is_enum<T>{}>>
    T auto_convert_param(rank<0>, T x)
    {
        return x;
    }

    // Handle parameter: wrap, taking ownership.
    template <class T>
    auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x})
    {
        return as_handle<T>{x};
    }

    // Handle parameter: preferred — wrap without taking ownership.
    template <class T>
    auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}})
    {
        return as_handle<T>{x, borrow{}};
    }

    // Plain value result: assign directly.
    // BUG FIX: the original wrote `return *out = x;` — a return statement
    // with a non-void expression in a void function, which is ill-formed
    // once this overload is instantiated (e.g. for fundamental results).
    template <class T, class U>
    void auto_assign(rank<0>, T* out, U x)
    {
        *out = x;
    }

    // Handle-wrapper result: preferred when x knows how to write itself
    // into the C out-handle.
    template <class T, class U>
    auto auto_assign(rank<1>, T* out, U x) -> decltype(x.assign_to_handle(out))
    {
        x.assign_to_handle(out);
    }
};
// Lift member function `name` of implementation type T into the C API
// callback slot installed by migraphx_<prefix>_set_<name>.
// NOLINTNEXTLINE
#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name)                \
    this->set_auto_fp<T>(&migraphx_##prefix##_set_##name,       \
                         [](T& x, auto... xs) { return x.name(xs...); })

// SFINAE guard: enabled only for final, copy-constructible types that
// derive from (but are not themselves) the interface base class — i.e.
// valid user implementations of an interface.
template <class Base, class T>
using require_interface =
    std::enable_if_t<std::is_base_of<Base, T>{} and not std::is_same<T, Base>{} and
                     std::is_copy_constructible<T>{} and std::is_final<T>{}>;
#ifdef DOXYGEN #ifdef DOXYGEN
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<> #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else #else
...@@ -988,6 +1130,32 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op ...@@ -988,6 +1130,32 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
options.get_handle_ptr()); options.get_handle_ptr());
} }
// Base class users derive from to implement a custom operator.
// Implementations must be final and copy-constructible to satisfy
// require_interface<experimental_custom_op_base, T>.
struct experimental_custom_op_base
{
    // Unique operator name under which the op is registered.
    virtual std::string name() const = 0;
    // Compute the output shape from the input shapes.
    virtual shape compute_shape(shapes inputs) const = 0;
    virtual ~experimental_custom_op_base() = default;
};
// Handle wrapper adapting a user custom-op object to the C API.
struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
{
    // Wrap `obj` and lift its compute_shape method into the C callback slot.
    // NOTE(review): obj.name() returns a temporary std::string whose c_str()
    // is passed to the create call — assumes the C API copies the name
    // during creation; confirm against migraphx_experimental_custom_op_create.
    template <class T>
    experimental_custom_op(T& obj)
    {
        this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str());
        MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
    }
    // Register this custom op with MIGraphX so programs can refer to it.
    void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
};
template <class T, class = require_interface<experimental_custom_op_base, T>>
void register_experimental_custom_op(T& obj)
{
experimental_custom_op op{obj};
op.register_op();
}
#ifndef DOXYGEN #ifndef DOXYGEN
} // namespace api } // namespace api
#endif #endif
......
...@@ -403,3 +403,13 @@ api.add_function('migraphx_quantize_int8', ...@@ -403,3 +403,13 @@ api.add_function('migraphx_quantize_int8',
@auto_handle(ref=True) @auto_handle(ref=True)
def context(h): def context(h):
h.method('finish', const=True) h.method('finish', const=True)
# Declare the C API surface generated for migraphx::experimental_custom_op.
@api.interface('migraphx_experimental_custom_op',
               'migraphx::experimental_custom_op')
def experimental_custom_op(h):
    # Constructor takes the op's name as a C string.
    h.constructor('create', api.params(name='const char*'))
    # Virtual hook supplied by the user: input shapes -> output shape.
    h.virtual('compute_shape',
              api.params(inputs='std::vector<migraphx::shape>'),
              returns='migraphx::shape')
    # register() forwards the wrapped op to migraphx::register_custom_op.
    h.method('register', invoke='migraphx::register_custom_op($@)')
...@@ -34,7 +34,14 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const ...@@ -34,7 +34,14 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
params += " -o " + out; params += " -o " + out;
td.execute(compiler, params); if(not launcher.empty())
{
td.execute(launcher, compiler + " " + params);
}
else
{
td.execute(compiler, params);
}
auto out_path = td.path / out; auto out_path = td.path / out;
if(not fs::exists(out_path)) if(not fs::exists(out_path))
......
...@@ -32,18 +32,22 @@ struct allocation_model ...@@ -32,18 +32,22 @@ struct allocation_model
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct allocation_model struct allocation_model
* { {
* std::string name() const; //
* std::string copy() const; std::string name() const;
* operation allocate(const shape& s) const; //
* operation preallocate(const shape& s,std::string id) const; std::string copy() const;
* }; //
* operation allocate(const shape& s) const;
*/ //
operation preallocate(const shape& s, std::string id) const;
};
#else
struct allocation_model struct allocation_model
{ {
...@@ -260,6 +264,7 @@ inline const ValueType& any_cast(const allocation_model& x) ...@@ -260,6 +264,7 @@ inline const ValueType& any_cast(const allocation_model& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -23,6 +23,7 @@ struct src_compiler ...@@ -23,6 +23,7 @@ struct src_compiler
std::string compiler = "c++"; std::string compiler = "c++";
std::string flags = ""; std::string flags = "";
std::string output = ""; std::string output = "";
std::string launcher = "";
std::function<fs::path(fs::path)> process = nullptr; std::function<fs::path(fs::path)> process = nullptr;
std::vector<char> compile(const std::vector<src_file>& srcs) const; std::vector<char> compile(const std::vector<src_file>& srcs) const;
}; };
......
...@@ -30,17 +30,20 @@ struct concat_optimization ...@@ -30,17 +30,20 @@ struct concat_optimization
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct concat_optimization struct concat_optimization
* { {
* std::string name() const; //
* std::string allocate() const; std::string name() const;
* op::concat get_concat(const operation& op) const; //
* }; std::string allocate() const;
* //
*/ op::concat get_concat(const operation& op) const;
};
#else
struct concat_optimization struct concat_optimization
{ {
...@@ -244,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x) ...@@ -244,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -44,18 +44,22 @@ any_ptr get_queue_context(T&) ...@@ -44,18 +44,22 @@ any_ptr get_queue_context(T&)
return {}; return {};
} }
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct context struct context
* { {
* value to_value() const; // (optional)
* void from_value(const value& v) ; value to_value() const;
* any_ptr get_queue() ; // (optional)
* void finish() const; void from_value(const value& v);
* }; // (optional)
* any_ptr get_queue();
*/ //
void finish() const;
};
#else
struct context struct context
{ {
...@@ -316,6 +320,7 @@ inline const ValueType& any_cast(const context& x) ...@@ -316,6 +320,7 @@ inline const ValueType& any_cast(const context& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); } inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); } inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
......
...@@ -20,18 +20,22 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -20,18 +20,22 @@ inline namespace MIGRAPHX_INLINE_NS {
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct marker struct marker
* { {
* void mark_start(instruction_ref ins_ref) ; //
* void mark_start(const program& prog) ; void mark_start(instruction_ref ins_ref);
* void mark_stop(instruction_ref ins) ; //
* void mark_stop(const program& prog) ; void mark_start(const program& prog);
* }; //
* void mark_stop(instruction_ref ins);
*/ //
void mark_stop(const program& prog);
};
#else
struct marker struct marker
{ {
...@@ -243,6 +247,7 @@ inline const ValueType& any_cast(const marker& x) ...@@ -243,6 +247,7 @@ inline const ValueType& any_cast(const marker& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -43,7 +43,7 @@ struct roialign ...@@ -43,7 +43,7 @@ struct roialign
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3).standard(); check_shapes{inputs, *this}.has(3);
auto x_lens = inputs.at(0).lens(); auto x_lens = inputs.at(0).lens();
auto roi_lens = inputs.at(1).lens(); auto roi_lens = inputs.at(1).lens();
auto bi_lens = inputs.at(2).lens(); auto bi_lens = inputs.at(2).lens();
......
...@@ -445,35 +445,62 @@ lifetime get_lifetime_op(const T&) ...@@ -445,35 +445,62 @@ lifetime get_lifetime_op(const T&)
} // namespace detail } // namespace detail
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct operation struct operation
* { {
* std::string name() const; //
* bool is_context_free() const; std::string name() const;
* bool need_normalization() const; // (optional)
* bool has_finalize() const; bool is_context_free() const;
* lifetime get_lifetime() const; // (optional)
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const; bool need_normalization() const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ; // (optional)
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ; bool has_finalize() const;
* shape compute_shape(const std::vector<shape>& input) const; // (optional)
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>& lifetime get_lifetime() const;
* mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>& // (optional)
* input) const; argument compute(const shape& output,const std::vector<argument>& input) const; std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input,const // (optional)
* std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const value compile(context& ctx, const shape& output, const std::vector<shape>& input);
* std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const // (optional)
* shape& output,const std::vector<argument>& input,const std::vector<module_ref>& void finalize(context& ctx, const shape& output, const std::vector<shape>& input);
* module_args,std::function<std::vector<argument>(module_ref&, const // (optional)
* std::unordered_map<std::string, argument>&)> run) const; value to_value() const; void shape compute_shape(const std::vector<shape>& input) const;
* from_value(const value& v) ; value attributes() const; friend std::ostream & // (optional)
* operator<<(std::ostream & os,const operation & op) ; friend bool operator==(const operation & shape compute_shape(const std::vector<shape>& inputs,
* x,const operation & y) ; const std::vector<module_ref>& mod_args) const;
* }; // (optional)
* argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
*/ // (optional)
argument compute(const shape& output, const std::vector<argument>& input) const;
// (optional)
argument compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
// (optional)
argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
// (optional)
value to_value() const;
// (optional)
void from_value(const value& v);
// (optional)
value attributes() const;
//
friend std::ostream& operator<<(std::ostream& os, const operation& op);
//
friend bool operator==(const operation& x, const operation& y);
};
#else
struct operation struct operation
{ {
...@@ -1222,6 +1249,7 @@ inline const ValueType& any_cast(const operation& x) ...@@ -1222,6 +1249,7 @@ inline const ValueType& any_cast(const operation& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); } inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment