"test/vscode:/vscode.git/clone" did not exist on "b47d3a8528966f71d72dfdbc2f033eebb9d8f280"
Commit 1a096ad1 authored by Paul's avatar Paul
Browse files

Merge

parents 34a51892 cb801f60
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
cmake_minimum_required(VERSION 3.5)
project (custom_hip_kernel)
set (CMAKE_CXX_STANDARD 14)
set (EXAMPLE custom_op_hip_kernel)
list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm)
find_package (migraphx REQUIRED)
find_package (hip REQUIRED)
message("source file: " ${EXAMPLE}.cpp " ---> bin: " ${EXAMPLE})
add_executable(${EXAMPLE} ${EXAMPLE}.cpp)
target_link_libraries(${EXAMPLE} migraphx::c hip::device)
# Custom HIP Kernel using MIGraphX API
This is an example of a custom operator implementation using MIGraphX's C/C++ APIs. It also demonstrates how to use the custom op together with the rest of MIGraphX's operators to build and run a MIGraphX program on the GPU.
Kernels can be written directly in HIP or implemented through the MIOpen or rocBLAS libraries. This particular example uses **HIP**; a short orientation sketch follows.
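Before diving into the custom op, here is a minimal sketch of the MIGraphX C++ API workflow the example builds on. It uses only calls that also appear in the full source below; the shape and input data are illustrative assumptions:

```cpp
#include <migraphx/migraphx.hpp> // MIGraphX's C++ API
#include <iostream>
#include <vector>

int main()
{
    migraphx::program p;
    migraphx::shape s{migraphx_shape_float_type, {4}};
    migraphx::module m = p.get_main_module();
    auto x    = m.add_parameter("x", s);
    auto relu = m.add_instruction(migraphx::operation("relu"), {x});
    m.add_return({relu});
    migraphx::compile_options options;
    options.set_offload_copy(); // copy inputs/outputs to and from GPU memory automatically
    p.compile(migraphx::target("gpu"), options);
    migraphx::program_parameters pp;
    std::vector<float> x_data{-2, -1, 1, 2};
    pp.add("x", migraphx::argument(s, x_data.data()));
    auto result = p.eval(pp)[0]; // expect {0, 0, 1, 2}
    std::cout << "ran minimal MIGraphX program\n";
    return 0;
}
```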
To build and run the example, ensure ROCm is installed at `/opt/rocm`.
1. `export LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH`
2. `cd $MIGRAPHX_SRC/examples/migraphx/custom_op_hip_kernel/`
3. `mkdir build && cd build`
4. `CXX=/opt/rocm/llvm/bin/clang++ cmake .. && make`
5. `./custom_op_hip_kernel`
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <algorithm>
#include <hip/hip_runtime.h>
#include <migraphx/migraphx.hpp> // MIGraphX's C++ API
#include <numeric>
#define MIGRAPHX_HIP_ASSERT(x) (assert((x) == hipSuccess))
/*
* Square each element in the array A and write to array C.
*/
template <typename T>
__global__ void vector_square(T* C_d, const T* A_d, size_t N)
{
size_t offset = (hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x);
size_t stride = hipBlockDim_x * hipGridDim_x;
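// Grid-stride loop: each thread handles elements offset, offset + stride, ...,
// so any launch configuration covers all N elements.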
for(size_t i = offset; i < N; i += stride)
{
C_d[i] = A_d[i] * A_d[i];
}
}
struct square_custom_op final : migraphx::experimental_custom_op_base
{
virtual std::string name() const override { return "square_custom_op"; }
virtual migraphx::argument
compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override
{
// If the compile options set offload_copy = true, parameters and outputs are copied
// to and from GPU memory automatically. The `inputs` arguments can therefore be
// assumed to reside on the GPU already, so no hipMalloc, hipFree, or hipMemcpy is
// needed here. The last element of `inputs` is the output argument, and it is what
// the compute method should return.
auto* input_buffer = reinterpret_cast<float*>(inputs[0].data());
auto* output_buffer = reinterpret_cast<float*>(inputs[1].data());
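// NOTE: bytes() reports the buffer size in bytes; sizeof() of the type enum happens
// to match sizeof(float) here, so the division yields the element count.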
size_t n_elements = inputs[0].get_shape().bytes() / sizeof(inputs[0].get_shape().type());
MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
const unsigned blocks = 512;
const unsigned threads_per_block = 256;
// cppcheck-suppress UseDeviceLaunch
hipLaunchKernelGGL(vector_square,
dim3(blocks),
dim3(threads_per_block),
0,
ctx.get_queue<hipStream_t>(),
output_buffer,
input_buffer,
n_elements);
return inputs[1];
}
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{
if(inputs.size() != 2)
{
throw std::runtime_error("square_custom_op must have 2 arguments");
}
if(inputs[0] != inputs[1])
{
throw std::runtime_error("Inputs to the square_custom_op must have same Shape");
}
return inputs.back();
}
};
int main(int argc, const char* argv[])
{
square_custom_op square_op;
migraphx::register_experimental_custom_op(square_op);
migraphx::program p;
migraphx::shape s{migraphx_shape_float_type, {32, 256}};
migraphx::module m = p.get_main_module();
auto x = m.add_parameter("x", s);
auto neg_ins = m.add_instruction(migraphx::operation("neg"), x);
// add allocation for the custom_kernel's output buffer
auto alloc = m.add_allocation(s);
auto custom_kernel =
m.add_instruction(migraphx::operation("square_custom_op"), {neg_ins, alloc});
auto relu_ins = m.add_instruction(migraphx::operation("relu"), {custom_kernel});
m.add_return({relu_ins});
migraphx::compile_options options;
// set offload copy to true for GPUs
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters pp;
std::vector<float> x_data(s.bytes() / sizeof(s.type()));
std::iota(x_data.begin(), x_data.end(), 0);
pp.add("x", migraphx::argument(s, x_data.data()));
auto results = p.eval(pp);
auto result = results[0];
std::vector<float> expected_result = x_data;
std::transform(expected_result.begin(),
expected_result.end(),
expected_result.begin(),
[](auto i) { return std::pow(i, 2); });
if(bool{result == migraphx::argument(s, expected_result.data())})
{
std::cout << "Successfully executed custom HIP kernel example\n";
}
else
{
std::cout << "Custom HIP kernel example failed\n";
}
return 0;
}
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
cmake_minimum_required(VERSION 3.5)
project (custom_miopen_kernel)
set (CMAKE_CXX_STANDARD 14)
set (EXAMPLE custom_op_miopen_kernel)
list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm)
find_package (migraphx REQUIRED)
find_package (miopen REQUIRED)
message("source file: " ${EXAMPLE}.cpp " ---> bin: " ${EXAMPLE})
add_executable(${EXAMPLE} ${EXAMPLE}.cpp)
target_link_libraries(${EXAMPLE} migraphx::c MIOpen)
# Custom MIOpen Kernel using MIGraphX API
This is an example of a custom operator implementation using MIGraphX's C/C++ APIs. It also demonstrates how to use the custom op together with the rest of MIGraphX's operators to build and run a MIGraphX program on the GPU.
Kernels can be written directly in HIP or implemented through the MIOpen or rocBLAS libraries. This particular example uses **MIOpen** library calls; the key stream-handling idiom is sketched below.
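The central idiom, shown in full in the source below, is binding MIOpen to the HIP stream MIGraphX executes on, so that the library call is ordered correctly with the surrounding operators. A minimal sketch, with error checking omitted:

```cpp
#include <hip/hip_runtime.h>
#include <migraphx/migraphx.hpp>
#include <miopen/miopen.h>

// Create an MIOpen handle bound to MIGraphX's HIP stream (sketch; no error checks).
miopenHandle_t make_handle(migraphx::context& ctx)
{
    auto* stream = ctx.get_queue<hipStream_t>();
    miopenHandle_t handle;
    miopenCreateWithStream(&handle, stream);
    return handle;
}
```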
To build and run the example, ensure ROCm is installed at `/opt/rocm`.
1. `export LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH`
2. `cd $MIGRAPHX_SRC/examples/migraphx/custom_op_miopen_kernel/`
3. `mkdir build && cd build`
4. `cmake .. && make`
5. `./custom_op_miopen_kernel`
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <algorithm>
#include <hip/hip_runtime.h>
#include <migraphx/migraphx.h>
#include <miopen/miopen.h>
#include <migraphx/migraphx.hpp> // MIGraphX's C++ API
#include <numeric>
#include <stdexcept>
#define MIGRAPHX_MIOPEN_ASSERT(x) (assert((x) == miopenStatusSuccess))
#define MIGRAPHX_HIP_ASSERT(x) (assert((x) == hipSuccess))
inline miopenTensorDescriptor_t make_miopen_tensor(const migraphx::shape& s, bool pack = false)
{
miopenTensorDescriptor_t t;
MIGRAPHX_MIOPEN_ASSERT(miopenCreateTensorDescriptor(&t));
// Convert to ints
auto s_lens = s.lengths();
std::vector<int> lens(s_lens.begin(), s_lens.end());
auto s_strides = s.strides();
std::vector<int> strides(s_strides.begin(), s_strides.end());
miopenDataType_t d;
if(s.type() == migraphx_shape_float_type)
d = miopenFloat;
else if(s.type() == migraphx_shape_half_type)
d = miopenHalf;
else if(s.type() == migraphx_shape_int32_type)
d = miopenInt32;
else if(s.type() == migraphx_shape_int8_type)
{
if(pack)
{
// update the lens and corresponding strides
d = miopenInt8x4;
lens[1] = ((lens[1] + 3) / 4) * 4;
strides[0] = strides[1] * lens[1];
}
else
{
d = miopenInt8;
}
}
else
{
throw("MAKE_TENSOR: unsupported type");
}
miopenSetTensorDescriptor(t, d, s_lens.size(), lens.data(), strides.data());
return t;
}
inline auto make_miopen_handle(migraphx::context& ctx)
{
MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
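// Bind MIOpen to the HIP stream MIGraphX executes on, so this op is ordered
// correctly with neighboring instructions.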
auto* stream = ctx.get_queue<hipStream_t>();
miopenHandle_t out;
MIGRAPHX_MIOPEN_ASSERT(miopenCreateWithStream(&out, stream));
return out;
}
inline auto make_activation_descriptor(miopenActivationMode_t mode,
double alpha = 0,
double beta = 0,
double gamma = 0)
{
miopenActivationDescriptor_t ad;
MIGRAPHX_MIOPEN_ASSERT(miopenCreateActivationDescriptor(&ad));
miopenSetActivationDescriptor(ad, mode, alpha, beta, gamma);
return ad;
}
struct abs_custom_op final : migraphx::experimental_custom_op_base
{
virtual std::string name() const override { return "abs_custom_op"; }
virtual migraphx::argument compute(migraphx::context ctx,
migraphx::shape output_shape,
migraphx::arguments args) const override
{
float alpha = 1;
float beta = 0;
// An MIOpen kernel call takes raw buffer pointers for the tensor data. Each buffer
// pointer must be accompanied by a tensor descriptor giving its shape, type, strides,
// and dimensionality. `make_miopen_tensor` creates such descriptors to pass as
// parameters to the MIOpen kernel call.
auto y_desc = make_miopen_tensor(output_shape);
auto x_desc = make_miopen_tensor(args[0].get_shape());
// create MIOpen stream handle
auto miopen_handle = make_miopen_handle(ctx);
// MIOpen has a generic kernel for many different kinds of activation functions.
// Each such generic call must be accompanied by a descriptor of which activation
// computation to perform.
auto ad = make_activation_descriptor(miopenActivationABS, 0, 0, 0);
miopenActivationForward(
miopen_handle, ad, &alpha, x_desc, args[0].data(), &beta, y_desc, args[1].data());
return args[1];
}
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{
if(inputs.size() != 2)
{
throw std::runtime_error("abs_custom_op must have two input arguments");
}
if(inputs[0] != inputs[1])
{
throw std::runtime_error("Input arguments to abs_custom_op must have same shape");
}
return inputs.back();
}
};
int main(int argc, const char* argv[])
{
abs_custom_op abs_op;
migraphx::register_experimental_custom_op(abs_op);
migraphx::program p;
migraphx::shape s{migraphx_shape_float_type, {32, 256}};
migraphx::module m = p.get_main_module();
auto x = m.add_parameter("x", s);
auto neg_ins = m.add_instruction(migraphx::operation("neg"), {x});
// add allocation for the custom_kernel's output buffer
auto alloc = m.add_allocation(s);
auto custom_kernel = m.add_instruction(migraphx::operation("abs_custom_op"), {neg_ins, alloc});
auto relu_ins = m.add_instruction(migraphx::operation("relu"), {custom_kernel});
m.add_return({relu_ins});
migraphx::compile_options options;
// set offload copy to true for GPUs
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters prog_params;
std::vector<float> x_data(s.bytes() / sizeof(s.type()));
std::iota(x_data.begin(), x_data.end(), 0);
prog_params.add("x", migraphx::argument(s, x_data.data()));
auto results = p.eval(prog_params);
auto result = results[0];
std::vector<float> expected_result = x_data;
std::transform(expected_result.begin(),
expected_result.end(),
expected_result.begin(),
[](auto i) { return std::abs(i); });
if(bool{result == migraphx::argument(s, expected_result.data())})
{
std::cout << "Successfully executed custom MIOpen kernel example with MIGraphX\n";
}
else
{
std::cout << "Custom MIOpen kernel example failed\n";
}
return 0;
}
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
cmake_minimum_required(VERSION 3.5)
project (custom_rocblas_kernel)
set (CMAKE_CXX_STANDARD 14)
set (EXAMPLE custom_op_rocblas_kernel)
list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm)
find_package (migraphx REQUIRED)
find_package (rocblas REQUIRED)
message("source file: " ${EXAMPLE}.cpp " ---> bin: " ${EXAMPLE})
add_executable(${EXAMPLE} ${EXAMPLE}.cpp)
target_link_libraries(${EXAMPLE} migraphx::c roc::rocblas)
# Custom rocBLAS Kernel using MIGraphX API
This is an example of a custom operator implementation using MIGraphX's C/C++ APIs. It also demonstrates how to use the custom op together with the rest of MIGraphX's operators to build and run a MIGraphX program on the GPU.
Kernels can be written directly in HIP or implemented through the MIOpen or rocBLAS libraries. This particular example uses **rocBLAS** library calls; the core call is sketched below.
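The core of the example, shown in full below, is pointing rocBLAS at MIGraphX's HIP stream and calling `rocblas_sscal` on buffers that already live on the GPU. A minimal sketch, with error checking omitted and an illustrative function name:

```cpp
#include <hip/hip_runtime.h>
#include <migraphx/migraphx.hpp>
#include <rocblas.h>

// Scale a device vector in place (x <- alpha * x) on MIGraphX's stream (sketch).
void scale_on_stream(migraphx::context& ctx, const float* alpha, float* x, rocblas_int n)
{
    rocblas_handle handle;
    rocblas_create_handle(&handle);
    rocblas_set_stream(handle, ctx.get_queue<hipStream_t>());
    rocblas_sscal(handle, n, alpha, x, 1);
    rocblas_destroy_handle(handle);
}
```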
To build and run the example, ensure ROCm is installed at `/opt/rocm`.
1. `export LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH`
2. `cd $MIGRAPHX_SRC/examples/migraphx/custom_op_rocblas_kernel/`
3. `mkdir build && cd build`
4. `cmake .. && make`
5. `./custom_op_rocblas_kernel`
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <algorithm>
#include <hip/hip_runtime.h>
#include <rocblas.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> // MIGraphX's C++ API
#include <numeric>
#include <stdexcept>
#define MIGRAPHX_ROCBLAS_ASSERT(x) (assert((x) == rocblas_status::rocblas_status_success))
#define MIGRAPHX_HIP_ASSERT(x) (assert((x) == hipSuccess))
rocblas_handle create_rocblas_handle_ptr()
{
rocblas_handle handle;
MIGRAPHX_ROCBLAS_ASSERT(rocblas_create_handle(&handle));
return rocblas_handle{handle};
}
rocblas_handle create_rocblas_handle_ptr(migraphx::context& ctx)
{
MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
rocblas_handle rb = create_rocblas_handle_ptr();
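// Point rocBLAS at the HIP stream MIGraphX executes on, so the sscal call is
// ordered correctly with neighboring instructions.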
auto* stream = ctx.get_queue<hipStream_t>();
MIGRAPHX_ROCBLAS_ASSERT(rocblas_set_stream(rb, stream));
return rb;
}
struct sscal_custom_op final : migraphx::experimental_custom_op_base
{
virtual std::string name() const override { return "sscal_custom_op"; }
virtual migraphx::argument compute(migraphx::context ctx,
migraphx::shape output_shape,
migraphx::arguments args) const override
{
// create rocblas stream handle
auto rocblas_handle = create_rocblas_handle_ptr(ctx);
rocblas_int n = args[1].get_shape().lengths()[0];
float* alpha = reinterpret_cast<float*>(args[0].data());
float* vec_ptr = reinterpret_cast<float*>(args[1].data());
MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rocblas_handle, n, alpha, vec_ptr, 1));
return args[1];
}
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{
if(inputs.size() != 2)
{
throw std::runtime_error("sscal_custom_op must have 2 input arguments");
}
if(inputs[0].lengths().size() != 1 || inputs[0].lengths()[0] != 1)
{
throw std::runtime_error("first input argument to sscal_custom_op must be a scalar");
}
if(inputs[1].lengths().size() != 1)
{
throw std::runtime_error(
"second input argument to sscal_custom_op must be a one-dimensional vector");
}
return inputs.back();
}
};
int main(int argc, const char* argv[])
{
// computes ReLU(neg(x) * scale)
sscal_custom_op sscal_op;
migraphx::register_experimental_custom_op(sscal_op);
migraphx::program p;
migraphx::shape x_shape{migraphx_shape_float_type, {8192}};
migraphx::shape scale_shape{migraphx_shape_float_type, {1}};
migraphx::module m = p.get_main_module();
auto x = m.add_parameter("x", x_shape);
auto scale = m.add_parameter("scale", scale_shape);
auto neg_ins = m.add_instruction(migraphx::operation("neg"), {x});
auto custom_kernel =
m.add_instruction(migraphx::operation("sscal_custom_op"), {scale, neg_ins});
auto relu_ins = m.add_instruction(migraphx::operation("relu"), {custom_kernel});
m.add_return({relu_ins});
migraphx::compile_options options;
// set offload copy to true for GPUs
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters pp;
std::vector<float> x_data(x_shape.bytes() / sizeof(x_shape.type()));
std::vector<float> scale_data{-1};
std::iota(x_data.begin(), x_data.end(), 0);
pp.add("x", migraphx::argument(x_shape, x_data.data()));
pp.add("scale", migraphx::argument(scale_shape, scale_data.data()));
auto results = p.eval(pp);
auto result = results[0];
std::vector<float> expected_result = x_data;
if(bool{result == migraphx::argument(x_shape, expected_result.data())})
{
std::cout << "Successfully executed custom rocBLAS kernel example\n";
}
else
{
std::cout << "Custom rocBLAS kernel example failed\n";
}
return 0;
}
 /*
  * The MIT License (MIT)
  *
@@ -21,10 +22,10 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#include <migraphx/operators.hpp>
+#include <migraphx/make_op.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/generate.hpp>
-#include <migraphx/apply_alpha_beta.hpp>
+#include <migraphx/json.hpp>
 #include "models.hpp"
 
 namespace migraphx {
@@ -34,173 +35,189 @@ inline namespace MIGRAPHX_INLINE_NS {
 migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
 {
     migraphx::program p;
-    auto* mm = p.get_main_module();
-    auto m0 =
-        mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
-    auto mx0 = mm->add_literal(
-        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0));
-    auto mx1 = mm->add_literal(
-        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1));
-    auto mx2 = mm->add_literal(
-        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2));
-    auto mx3 = mm->add_literal(
-        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3));
-    auto mx4 = mm->add_literal(
-        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4));
-    auto mx5 = mm->add_literal(
-        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5));
-    auto mx6 = mm->add_literal(
-        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6));
-    auto mx7 = mm->add_literal(migraphx::generate_literal(
-        migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7));
-    auto mx8 = mm->add_literal(
-        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8));
-    auto mx9 = mm->add_literal(migraphx::generate_literal(
-        migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9));
-    auto mx10 = mm->add_literal(
-        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10));
-    auto mx11 = mm->add_literal(migraphx::generate_literal(
-        migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
-    auto mx12 = mm->add_literal(
-        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12));
-    auto mx13 = mm->add_literal(migraphx::generate_literal(
-        migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13));
-    auto mx14 = mm->add_literal(
-        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14));
-    auto mx15 = mm->add_literal(migraphx::generate_literal(
-        migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15));
-    migraphx::op::convolution convolution16;
-    convolution16.padding = {2, 2};
-    convolution16.stride = {4, 4};
-    convolution16.dilation = {1, 1};
-    convolution16.group = 1;
-    auto mx16 = mm->add_instruction(convolution16, m0, mx15);
-    migraphx::op::broadcast broadcast17;
-    broadcast17.axis = 1;
-    broadcast17.broadcast_lens = {batch, 64, 55, 55};
-    auto mx17 = mm->add_instruction(broadcast17, mx14);
-    migraphx::op::add add18;
-    auto mx18 = mm->add_instruction(add18, mx16, mx17);
-    migraphx::op::relu relu19;
-    auto mx19 = mm->add_instruction(relu19, mx18);
-    migraphx::op::pooling pooling20;
-    pooling20.mode = migraphx::op::pooling_mode::max;
-    pooling20.padding = {0, 0};
-    pooling20.stride = {2, 2};
-    pooling20.lengths = {3, 3};
-    auto mx20 = mm->add_instruction(pooling20, mx19);
-    migraphx::op::convolution convolution21;
-    convolution21.padding = {2, 2};
-    convolution21.stride = {1, 1};
-    convolution21.dilation = {1, 1};
-    convolution21.group = 1;
-    auto mx21 = mm->add_instruction(convolution21, mx20, mx13);
-    migraphx::op::broadcast broadcast22;
-    broadcast22.axis = 1;
-    broadcast22.broadcast_lens = {batch, 192, 27, 27};
-    auto mx22 = mm->add_instruction(broadcast22, mx12);
-    migraphx::op::add add23;
-    auto mx23 = mm->add_instruction(add23, mx21, mx22);
-    migraphx::op::relu relu24;
-    auto mx24 = mm->add_instruction(relu24, mx23);
-    migraphx::op::pooling pooling25;
-    pooling25.mode = migraphx::op::pooling_mode::max;
-    pooling25.padding = {0, 0};
-    pooling25.stride = {2, 2};
-    pooling25.lengths = {3, 3};
-    auto mx25 = mm->add_instruction(pooling25, mx24);
-    migraphx::op::convolution convolution26;
-    convolution26.padding = {1, 1};
-    convolution26.stride = {1, 1};
-    convolution26.dilation = {1, 1};
-    convolution26.group = 1;
-    auto mx26 = mm->add_instruction(convolution26, mx25, mx11);
-    migraphx::op::broadcast broadcast27;
-    broadcast27.axis = 1;
-    broadcast27.broadcast_lens = {batch, 384, 13, 13};
-    auto mx27 = mm->add_instruction(broadcast27, mx10);
-    migraphx::op::add add28;
-    auto mx28 = mm->add_instruction(add28, mx26, mx27);
-    migraphx::op::relu relu29;
-    auto mx29 = mm->add_instruction(relu29, mx28);
-    migraphx::op::convolution convolution30;
-    convolution30.padding = {1, 1};
-    convolution30.stride = {1, 1};
-    convolution30.dilation = {1, 1};
-    convolution30.group = 1;
-    auto mx30 = mm->add_instruction(convolution30, mx29, mx9);
-    migraphx::op::broadcast broadcast31;
-    broadcast31.axis = 1;
-    broadcast31.broadcast_lens = {batch, 256, 13, 13};
-    auto mx31 = mm->add_instruction(broadcast31, mx8);
-    migraphx::op::add add32;
-    auto mx32 = mm->add_instruction(add32, mx30, mx31);
-    migraphx::op::relu relu33;
-    auto mx33 = mm->add_instruction(relu33, mx32);
-    migraphx::op::convolution convolution34;
-    convolution34.padding = {1, 1};
-    convolution34.stride = {1, 1};
-    convolution34.dilation = {1, 1};
-    convolution34.group = 1;
-    auto mx34 = mm->add_instruction(convolution34, mx33, mx7);
-    migraphx::op::broadcast broadcast35;
-    broadcast35.axis = 1;
-    broadcast35.broadcast_lens = {batch, 256, 13, 13};
-    auto mx35 = mm->add_instruction(broadcast35, mx6);
-    migraphx::op::add add36;
-    auto mx36 = mm->add_instruction(add36, mx34, mx35);
-    migraphx::op::relu relu37;
-    auto mx37 = mm->add_instruction(relu37, mx36);
-    migraphx::op::pooling pooling38;
-    pooling38.mode = migraphx::op::pooling_mode::max;
-    pooling38.padding = {0, 0};
-    pooling38.stride = {2, 2};
-    pooling38.lengths = {3, 3};
-    auto mx38 = mm->add_instruction(pooling38, mx37);
-    migraphx::op::flatten flatten39;
-    flatten39.axis = 1;
-    auto mx39 = mm->add_instruction(flatten39, mx38);
-    migraphx::op::identity identity40;
-    auto mx40 = mm->add_instruction(identity40, mx39);
-    migraphx::op::transpose transpose41;
-    transpose41.dims = {1, 0};
-    auto mx41 = mm->add_instruction(transpose41, mx5);
-    migraphx::op::multibroadcast multibroadcast42;
-    multibroadcast42.output_lens = {batch, 4096};
-    auto mx42 = mm->add_instruction(multibroadcast42, mx4);
-    float dot43_alpha = 1;
-    float dot43_beta = 1;
-    auto mx43 = migraphx::add_apply_alpha_beta(
-        *mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta);
-    migraphx::op::relu relu44;
-    auto mx44 = mm->add_instruction(relu44, mx43);
-    migraphx::op::identity identity45;
-    auto mx45 = mm->add_instruction(identity45, mx44);
-    migraphx::op::transpose transpose46;
-    transpose46.dims = {1, 0};
-    auto mx46 = mm->add_instruction(transpose46, mx3);
-    migraphx::op::multibroadcast multibroadcast47;
-    multibroadcast47.output_lens = {batch, 4096};
-    auto mx47 = mm->add_instruction(multibroadcast47, mx2);
-    float dot48_alpha = 1;
-    float dot48_beta = 1;
-    auto mx48 = migraphx::add_apply_alpha_beta(
-        *mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta);
-    migraphx::op::relu relu49;
-    auto mx49 = mm->add_instruction(relu49, mx48);
-    migraphx::op::transpose transpose50;
-    transpose50.dims = {1, 0};
-    auto mx50 = mm->add_instruction(transpose50, mx1);
-    migraphx::op::multibroadcast multibroadcast51;
-    multibroadcast51.output_lens = {batch, 1000};
-    auto mx51 = mm->add_instruction(multibroadcast51, mx0);
-    float dot52_alpha = 1;
-    float dot52_beta = 1;
-    migraphx::add_apply_alpha_beta(
-        *mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta);
+    migraphx::module_ref mmain = p.get_main_module();
+    auto x_main_module_0 = mmain->add_literal(migraphx::abs(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 0)));
+    auto x_main_module_1 = mmain->add_literal(migraphx::abs(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1)));
+    auto x_main_module_2 = mmain->add_literal(migraphx::abs(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2)));
+    auto x_input_1 = mmain->add_parameter(
+        "input.1", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
+    auto x_main_module_4 = mmain->add_literal(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3));
+    auto x_main_module_5 = mmain->add_literal(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4));
+    auto x_main_module_6 = mmain->add_literal(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5));
+    auto x_main_module_7 = mmain->add_literal(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 6));
+    auto x_main_module_8 = mmain->add_literal(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 7));
+    auto x_main_module_9 = mmain->add_literal(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 8));
+    auto x_main_module_10 = mmain->add_literal(migraphx::generate_literal(
+        migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9));
+    auto x_main_module_11 = mmain->add_literal(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 10));
+    auto x_main_module_12 = mmain->add_literal(migraphx::generate_literal(
+        migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
+    auto x_main_module_13 = mmain->add_literal(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 12));
+    auto x_main_module_14 = mmain->add_literal(migraphx::generate_literal(
+        migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13));
+    auto x_main_module_15 = mmain->add_literal(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 14));
+    auto x_main_module_16 = mmain->add_literal(migraphx::generate_literal(
+        migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 15));
+    auto x_main_module_17 = mmain->add_literal(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 16));
+    auto x_main_module_18 = mmain->add_literal(migraphx::generate_literal(
+        migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 17));
+    auto x_main_module_19 = mmain->add_literal(
+        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 18));
+    auto x_main_module_20 = mmain->add_instruction(
+        migraphx::make_op(
+            "convolution",
+            migraphx::from_json_string(
+                "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}")),
+        x_input_1,
+        x_main_module_18);
+    auto x_main_module_21 = mmain->add_instruction(
+        migraphx::make_op("broadcast",
+                          migraphx::from_json_string("{axis:1,out_lens:[1,64,55,55]}")),
+        x_main_module_19);
+    auto x_main_module_22 =
+        mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21);
+    auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22);
+    auto x_main_module_24 = mmain->add_instruction(
+        migraphx::make_op(
+            "pooling",
+            migraphx::from_json_string(
+                "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
+        x_main_module_23);
+    auto x_main_module_25 = mmain->add_instruction(
+        migraphx::make_op(
+            "convolution",
+            migraphx::from_json_string(
+                "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}")),
+        x_main_module_24,
+        x_main_module_14);
+    auto x_main_module_26 = mmain->add_instruction(
+        migraphx::make_op("broadcast",
+                          migraphx::from_json_string("{axis:1,out_lens:[1,192,27,27]}")),
+        x_main_module_15);
+    auto x_main_module_27 =
+        mmain->add_instruction(migraphx::make_op("add"), x_main_module_25, x_main_module_26);
+    auto x_main_module_28 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_27);
+    auto x_main_module_29 = mmain->add_instruction(
+        migraphx::make_op(
+            "pooling",
+            migraphx::from_json_string(
+                "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
+        x_main_module_28);
+    auto x_main_module_30 = mmain->add_instruction(
+        migraphx::make_op(
+            "convolution",
+            migraphx::from_json_string(
+                "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
+        x_main_module_29,
+        x_main_module_12);
+    auto x_main_module_31 = mmain->add_instruction(
+        migraphx::make_op("broadcast",
+                          migraphx::from_json_string("{axis:1,out_lens:[1,384,13,13]}")),
+        x_main_module_13);
+    auto x_main_module_32 =
+        mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
+    auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
+    auto x_main_module_34 = mmain->add_instruction(
+        migraphx::make_op(
+            "convolution",
+            migraphx::from_json_string(
+                "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
+        x_main_module_33,
+        x_main_module_10);
+    auto x_main_module_35 = mmain->add_instruction(
+        migraphx::make_op("broadcast",
+                          migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
+        x_main_module_11);
+    auto x_main_module_36 =
+        mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
+    auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
+    auto x_main_module_38 = mmain->add_instruction(
+        migraphx::make_op(
+            "convolution",
+            migraphx::from_json_string(
+                "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
+        x_main_module_37,
+        x_main_module_16);
+    auto x_main_module_39 = mmain->add_instruction(
+        migraphx::make_op("broadcast",
+                          migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
+        x_main_module_17);
+    auto x_main_module_40 =
+        mmain->add_instruction(migraphx::make_op("add"), x_main_module_38, x_main_module_39);
+    auto x_main_module_41 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_40);
+    auto x_main_module_42 = mmain->add_instruction(
+        migraphx::make_op(
+            "pooling",
+            migraphx::from_json_string(
+                "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
+        x_main_module_41);
+    auto x_main_module_43 = mmain->add_instruction(
+        migraphx::make_op("reshape", migraphx::from_json_string("{dims:[1,9216]}")),
+        x_main_module_42);
+    auto x_main_module_44 = mmain->add_instruction(
+        migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
+        x_main_module_6);
+    auto x_main_module_45 =
+        mmain->add_instruction(migraphx::make_op("dot"), x_main_module_43, x_main_module_44);
+    auto x_main_module_46 = mmain->add_instruction(
+        migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
+        x_main_module_7);
+    auto x_main_module_47 = mmain->add_instruction(
+        migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
+        x_main_module_2);
+    auto x_main_module_48 =
+        mmain->add_instruction(migraphx::make_op("mul"), x_main_module_46, x_main_module_47);
+    auto x_main_module_49 =
+        mmain->add_instruction(migraphx::make_op("add"), x_main_module_45, x_main_module_48);
+    auto x_main_module_50 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_49);
+    auto x_main_module_51 = mmain->add_instruction(
+        migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
+        x_main_module_4);
+    auto x_main_module_52 =
+        mmain->add_instruction(migraphx::make_op("dot"), x_main_module_50, x_main_module_51);
+    auto x_main_module_53 = mmain->add_instruction(
+        migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
+        x_main_module_5);
+    auto x_main_module_54 = mmain->add_instruction(
+        migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
+        x_main_module_1);
+    auto x_main_module_55 =
+        mmain->add_instruction(migraphx::make_op("mul"), x_main_module_53, x_main_module_54);
+    auto x_main_module_56 =
+        mmain->add_instruction(migraphx::make_op("add"), x_main_module_52, x_main_module_55);
+    auto x_main_module_57 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_56);
+    auto x_main_module_58 = mmain->add_instruction(
+        migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
+        x_main_module_8);
+    auto x_main_module_59 =
+        mmain->add_instruction(migraphx::make_op("dot"), x_main_module_57, x_main_module_58);
+    auto x_main_module_60 = mmain->add_instruction(
+        migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")),
+        x_main_module_9);
+    auto x_main_module_61 = mmain->add_instruction(
+        migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")),
+        x_main_module_0);
+    auto x_main_module_62 =
+        mmain->add_instruction(migraphx::make_op("mul"), x_main_module_60, x_main_module_61);
+    auto x_main_module_63 =
+        mmain->add_instruction(migraphx::make_op("add"), x_main_module_59, x_main_module_62);
+    mmain->add_return({x_main_module_63});
     return p;
 }
 
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace driver
 } // namespace migraphx
This diff is collapsed.
@@ -210,6 +210,9 @@ struct loader
             auto last = std::prev(mm->end(), trim);
             mm->remove_instructions(last, mm->end());
         }
+        // Remove unused variable when exporting to cpp
+        if(output_type == "cpp")
+            migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
         if(optimize)
         {
             migraphx::run_passes(*p.get_main_module(),
...
This diff is collapsed.
@@ -81,8 +81,9 @@ struct basic_iota_iterator
         index--;
         return it;
     }
-    // TODO: operator->
     reference operator*() const { return f(index); }
+    pointer operator->() const { return &f(index); }
+    reference operator[](int n) const { return f(index + n); }
 };
 
 template <class T, class F>
...
@@ -562,6 +562,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
     return nullopt;
 }
 
+MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
+{
+    return contains({"broadcast", "multibroadcast"}, ins->name());
+}
+
 template <class... Ms>
 auto skip(Ms... ms)
 {
@@ -811,8 +816,7 @@ inline auto has_attribute(const std::string& name)
 template <class... Ms>
 auto pointwise(Ms... ms)
 {
-    return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
-                                             ms...);
+    return match::has_attribute("pointwise")(ms...);
 }
 
 } // namespace match
...
@@ -179,11 +179,13 @@ struct module
     void print_cpp(std::ostream& os) const;
 
     std::unordered_map<instruction_ref, std::string>
-    print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const;
+    print_cpp(std::ostream& os,
+              const std::string& mname,
+              std::unordered_map<instruction_ref, std::string> names) const;
 
     void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
 
-    std::vector<module_ref> get_sub_modules() const;
+    std::vector<module_ref> get_sub_modules(bool shallow = false) const;
 
     module& sort();
 
     ins_dep_map calc_implicit_deps() const;
...
@@ -56,14 +56,21 @@ struct nonmaxsuppression
     shape compute_shape(std::vector<shape> inputs) const
     {
         // requires at least 2 inputs
-        check_shapes{inputs, *this}.standard();
         check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3);
         auto lens = inputs.front().lens();
 
         // check input shape
         if(lens[1] != inputs.at(1).lens()[2])
         {
-            MIGRAPHX_THROW("NonMaxSuppression: dimension mismatch between first and second input!");
+            MIGRAPHX_THROW(
+                "NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
+        }
+
+        // check batch sizes
+        if(lens[0] != inputs.at(1).lens()[0])
+        {
+            MIGRAPHX_THROW(
+                "NonMaxSuppression: number of batches mismatch between boxes and scores input");
         }
 
         std::vector<int64_t> out_lens(2);
@@ -74,8 +81,8 @@ struct nonmaxsuppression
     struct box
     {
-        std::array<float, 2> x;
-        std::array<float, 2> y;
+        std::array<double, 2> x;
+        std::array<double, 2> y;
 
         void sort()
         {
@@ -83,9 +90,9 @@ struct nonmaxsuppression
             std::sort(y.begin(), y.end());
         }
 
-        std::array<float, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
+        std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
 
-        float area() const
+        double area() const
         {
             assert(std::is_sorted(x.begin(), x.end()));
             assert(std::is_sorted(y.begin(), y.end()));
@@ -94,29 +101,29 @@ struct nonmaxsuppression
     };
 
     template <class T>
-    box batch_box(const T* boxes, std::size_t bidx) const
+    box batch_box(T boxes, std::size_t box_idx) const
     {
         box result{};
-        const T* start = boxes + 4 * bidx;
+        auto start = boxes + 4 * box_idx;
         if(center_point_box)
         {
-            float half_width = start[2] / 2.0f;
-            float half_height = start[3] / 2.0f;
-            float x_center = start[0];
-            float y_center = start[1];
+            double half_width = start[2] / 2.0;
+            double half_height = start[3] / 2.0;
+            double x_center = start[0];
+            double y_center = start[1];
             result.x = {x_center - half_width, x_center + half_width};
             result.y = {y_center - half_height, y_center + half_height};
         }
         else
         {
-            result.x = {start[1], start[3]};
-            result.y = {start[0], start[2]};
+            result.x = {static_cast<double>(start[1]), static_cast<double>(start[3])};
+            result.y = {static_cast<double>(start[0]), static_cast<double>(start[2])};
         }
         return result;
     }
 
-    inline bool suppress_by_iou(box b1, box b2, float iou_threshold) const
+    inline bool suppress_by_iou(box b1, box b2, double iou_threshold) const
     {
         b1.sort();
         b2.sort();
@@ -128,7 +135,7 @@ struct nonmaxsuppression
             intersection[i][1] = std::min(b1[i][1], b2[i][1]);
         }
 
-        std::vector<std::array<float, 2>> bbox = {intersection.x, intersection.y};
+        std::vector<std::array<double, 2>> bbox = {intersection.x, intersection.y};
         if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) {
                return not std::is_sorted(bx.begin(), bx.end());
            }))
@@ -136,115 +143,124 @@ struct nonmaxsuppression
             return false;
         }
 
-        const float area1 = b1.area();
-        const float area2 = b2.area();
-        const float intersection_area = intersection.area();
-        const float union_area = area1 + area2 - intersection_area;
+        const double area1 = b1.area();
+        const double area2 = b2.area();
+        const double intersection_area = intersection.area();
+        const double union_area = area1 + area2 - intersection_area;
         if(area1 <= .0f or area2 <= .0f or union_area <= .0f)
         {
             return false;
         }
-        const float intersection_over_union = intersection_area / union_area;
+        const double intersection_over_union = intersection_area / union_area;
         return intersection_over_union > iou_threshold;
     }
 
-    argument compute(const shape& output_shape, std::vector<argument> args) const
-    {
-        argument result{output_shape};
-        result.visit([&](auto out) { std::fill(out.begin(), out.end(), 0); });
-        std::size_t max_output_boxes_per_class = 0;
-        float iou_threshold = 0.0f;
-        float score_threshold = 0.0f;
-        if(args.size() > 2)
-        {
-            max_output_boxes_per_class = args.at(2).at<std::size_t>();
-        }
-        // max_output_boxes_per_class is 0, no output
-        if(max_output_boxes_per_class == 0)
-        {
-            return result;
-        }
-        if(args.size() > 3)
-        {
-            iou_threshold = args.at(3).at<float>();
-        }
-        if(args.size() > 4)
-        {
-            score_threshold = args.at(4).at<float>();
-        }
-        const auto& lens = args.at(1).get_shape().lens();
-        auto batch_num = lens[0];
-        auto class_num = lens[1];
-        auto box_num = args.at(0).get_shape().lens()[1];
-        std::vector<std::pair<float, int64_t>> selected_boxes_inside_class;
-        std::vector<int64_t> selected_indices;
-        selected_boxes_inside_class.reserve(output_shape.elements());
-        auto scores = make_view<float>(args.at(1).get_shape(), args.at(1).cast<float>());
-        const float* boxes = args.at(0).cast<float>();
-        shape comp_s{shape::float_type, {batch_num, class_num}};
-        shape_for_each(comp_s, [&](auto idx) {
-            auto bidx = idx[0];
-            auto cidx = idx[1];
-            std::size_t score_offset = (bidx * class_num + cidx) * box_num;
-            const float* batch_boxes = boxes + bidx * box_num * 4;
-            std::priority_queue<std::pair<float, int64_t>> sorted_boxes;
-            auto insert_to_sorted_boxes =
-                make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });
-            int64_t box_idx = 0;
-            transform_if(
-                scores.begin() + score_offset,
-                scores.begin() + score_offset + box_num,
-                insert_to_sorted_boxes,
-                [&](auto sc) {
-                    box_idx++;
-                    return sc >= score_threshold;
-                },
-                [&](auto sc) { return std::make_pair(sc, box_idx - 1); });
-            selected_boxes_inside_class.clear();
-            // Get the next box with top score, filter by iou_threshold
-            while(!sorted_boxes.empty() &&
-                  selected_boxes_inside_class.size() < max_output_boxes_per_class)
-            {
-                const std::pair<float, int64_t>& next_top_score = sorted_boxes.top();
-                // Check with existing selected boxes for this class, suppress if exceed the IOU
-                // (Intersection Over Union) threshold
-                bool not_selected = std::any_of(
-                    selected_boxes_inside_class.begin(),
-                    selected_boxes_inside_class.end(),
-                    [&](auto selected_index) {
-                        return this->suppress_by_iou(batch_box(batch_boxes, next_top_score.second),
                                                     batch_box(batch_boxes, selected_index.second),
-                                                     iou_threshold);
-                    });
-                if(not not_selected)
-                {
-                    selected_boxes_inside_class.push_back(next_top_score);
-                    selected_indices.push_back(bidx);
-                    selected_indices.push_back(cidx);
-                    selected_indices.push_back(next_top_score.second);
-                }
-                sorted_boxes.pop();
-            }
-        });
-        result.visit([&](auto out) {
-            std::copy(selected_indices.begin(), selected_indices.end(), out.begin());
-        });
+    // filter boxes below score_threshold
+    template <class T>
+    std::priority_queue<std::pair<double, int64_t>>
+    filter_boxes_by_score(T scores_start, std::size_t num_boxes, double score_threshold) const
+    {
+        std::priority_queue<std::pair<double, int64_t>> boxes_heap;
+        auto insert_to_boxes_heap =
+            make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); });
+        int64_t box_idx = 0;
+        transform_if(
+            scores_start,
+            scores_start + num_boxes,
+            insert_to_boxes_heap,
+            [&](auto sc) {
+                box_idx++;
+                return sc >= score_threshold;
+            },
+            [&](auto sc) { return std::make_pair(sc, box_idx - 1); });
+        return boxes_heap;
+    }
+
+    template <class Output, class Boxes, class Scores>
+    void compute_nms(Output output,
+                     Boxes boxes,
+                     Scores scores,
+                     const shape& output_shape,
+                     std::size_t max_output_boxes_per_class,
+                     double iou_threshold,
+                     double score_threshold) const
+    {
+        std::fill(output.begin(), output.end(), 0);
+        const auto& lens = scores.get_shape().lens();
+        const auto num_batches = lens[0];
+        const auto num_classes = lens[1];
+        const auto num_boxes = lens[2];
+        // boxes of a class with NMS applied [score, index]
+        std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
+        std::vector<int64_t> selected_indices;
+        selected_boxes_inside_class.reserve(output_shape.elements());
+        // iterate over batches and classes
+        shape comp_s{shape::double_type, {num_batches, num_classes}};
+        shape_for_each(comp_s, [&](auto idx) {
+            auto batch_idx = idx[0];
+            auto class_idx = idx[1];
+            // index offset for this class
+            auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes;
+            // iterator to first value of this batch
+            auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
+            auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
+            selected_boxes_inside_class.clear();
+            // Get the next box with top score, filter by iou_threshold
+            while(!boxes_heap.empty() &&
+                  selected_boxes_inside_class.size() < max_output_boxes_per_class)
+            {
+                // Check with existing selected boxes for this class, remove box if it
+                // exceeds the IOU (Intersection Over Union) threshold
+                const auto next_top_score = boxes_heap.top();
+                bool not_selected =
+                    std::any_of(selected_boxes_inside_class.begin(),
+                                selected_boxes_inside_class.end(),
+                                [&](auto selected_index) {
+                                    return this->suppress_by_iou(
+                                        batch_box(batch_boxes_start, next_top_score.second),
+                                        batch_box(batch_boxes_start, selected_index.second),
+                                        iou_threshold);
+                                });
+                if(not not_selected)
+                {
+                    selected_boxes_inside_class.push_back(next_top_score);
+                    selected_indices.push_back(batch_idx);
+                    selected_indices.push_back(class_idx);
+                    selected_indices.push_back(next_top_score.second);
+                }
+                boxes_heap.pop();
+            }
+        });
+        std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
+    }
+
+    argument compute(const shape& output_shape, std::vector<argument> args) const
+    {
+        argument result{output_shape};
+        std::size_t max_output_boxes_per_class =
+            (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
+        if(max_output_boxes_per_class == 0)
+        {
+            return result;
+        }
+        double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
+        double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
+        result.visit([&](auto output) {
+            visit_all(args[0], args[1])([&](auto boxes, auto scores) {
+                compute_nms(output,
+                            boxes,
+                            scores,
+                            output_shape,
+                            max_output_boxes_per_class,
+                            iou_threshold,
+                            score_threshold);
+            });
+        });
         return result;
...
@@ -38,6 +38,7 @@ struct module_pass_manager
     module_pass_manager(const module_pass_manager&) = delete;
     virtual module& get_module() = 0;
     virtual module* create_module(const std::string& name) = 0;
+    virtual module* get_common_parent() = 0;
     virtual void run_pass(const pass& p) = 0;
 
 protected:
...
@@ -132,6 +132,8 @@ struct program
     std::vector<const module*> get_modules() const;
     std::vector<module*> get_modules();
 
+    std::unordered_multimap<module_ref, module_ref> get_module_tree();
+
     void remove_module(const std::string& name);
     void remove_unused_modules();
...
@@ -216,6 +216,12 @@ bool equal(R1&& r1, R2&& r2, Predicate... pred)
     return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...);
 }
 
+template <class Range>
+auto distance(Range&& r)
+{
+    return std::distance(r.begin(), r.end());
+}
+
 template <class R>
 using range_value = std::decay_t<decltype(*std::declval<R>().begin())>;
...