Unverified Commit b30755ee authored by WilliamKyle's avatar WilliamKyle Committed by GitHub
Browse files

[Feature] Add rotated_feature_align cpu & onnxruntime implementation (#1878)

* add rotated_feature_align cpu implementation

* add rotated_feature_align onnxruntime implementation

* Update code for advices from grimoire

Remove useless comment from mmcv/ops/csrc/pytorch/cpu/rotated_feature_align.cpp

Replace ambiguous function name atomicAdd in mmcv/ops/csrc/pytorch/cpu/rotated_feature_align.cpp

Simplify unit test with parameter in tests/test_ops/test_rotated_feature_align.py

Use fma in interpolate in mmcv/ops/csrc/onnxruntime/cpu/rotated_feature_align.cpp mmcv/ops/csrc/pytorch/cpu/rotated_feature_align.cpp

* Inline the function to reduce the overhead of the function call

Use fma in interpolate
parent 4efec732
......@@ -10,6 +10,7 @@
#include "reduce_ops.h"
#include "roi_align.h"
#include "roi_align_rotated.h"
#include "rotated_feature_align.h"
#include "soft_nms.h"
const char *c_MMCVOpDomain = "mmcv";
......@@ -17,6 +18,7 @@ SoftNmsOp c_SoftNmsOp;
NmsOp c_NmsOp;
MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp;
MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp;
MMCVRotatedFeatureAlignCustomOp c_MMCVRotatedFeatureAlignCustomOp;
GridSampleOp c_GridSampleOp;
MMCVCumMaxCustomOp c_MMCVCumMaxCustomOp;
MMCVCumMinCustomOp c_MMCVCumMinCustomOp;
......@@ -77,5 +79,10 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
return status;
}
if (auto status = ortApi->CustomOpDomain_Add(
domain, &c_MMCVRotatedFeatureAlignCustomOp)) {
return status;
}
return ortApi->AddCustomOpDomain(options, domain);
}
// Modified from
// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu
#include "rotated_feature_align.h"
#include "../ort_mmcv_utils.h"
template <typename T>
T bilinear_interpolate(const T *input, const int height, const int width, T y,
T x, const int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
if (y <= 0) y = 0;
if (x <= 0) x = 0;
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
T v1 = input[int(fma(y_low, width, x_low))];
T v2 = input[int(fma(y_low, width, x_high))];
T v3 = input[int(fma(y_high, width, x_low))];
T v4 = input[int(fma(y_high, width, x_high))];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
void rotated_feature_align_forward_cpu_kernel(
const int nthreads, const int points, const scalar_t *bottom_data,
const scalar_t *best_bboxes, const scalar_t spatial_scale,
const int channels, const int height, const int width, scalar_t *top_data) {
for (int index = 0; index < nthreads; index++) {
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
const scalar_t *bbox_offset =
best_bboxes + ((n * height + h) * width + w) * 5;
scalar_t roi_y = bbox_offset[0] * spatial_scale;
scalar_t roi_x = bbox_offset[1] * spatial_scale;
scalar_t px[5] = {roi_x, 0, 0, 0, 0};
scalar_t py[5] = {roi_y, 0, 0, 0, 0};
if (points > 1) {
scalar_t roi_w = bbox_offset[2] * spatial_scale;
scalar_t roi_h = bbox_offset[3] * spatial_scale;
scalar_t roi_a = bbox_offset[4];
scalar_t w_2 = roi_w / 2, h_2 = roi_h / 2;
scalar_t cosa = cosf(roi_a), sina = sinf(roi_a);
scalar_t wx = cosa * w_2, wy = sina * w_2;
scalar_t hx = -sina * h_2, hy = cosa * h_2;
px[1] = roi_x + wx + hx;
py[1] = roi_y + wy + hy;
px[2] = roi_x - wx + hx;
py[2] = roi_y - wy + hy;
px[3] = roi_x - wx - hx;
py[3] = roi_y - wy - hy;
px[4] = roi_x + wx - hx;
py[4] = roi_y + wy - hy;
}
const scalar_t *offset_bottom_data =
bottom_data + (n * channels + c) * height * width;
scalar_t output_val = bottom_data[index];
for (int i = 0; i < points; i++) {
output_val += bilinear_interpolate<scalar_t>(offset_bottom_data, height,
width, py[i], px[i], i);
}
top_data[index] = output_val;
}
}
void MMCVRotatedFeatureAlignKernel::Compute(OrtKernelContext *context) {
// Setup inputs
const OrtValue *input_features = ort_.KernelContext_GetInput(context, 0);
const float *features_data = reinterpret_cast<const float *>(
ort_.GetTensorData<float>(input_features));
const OrtValue *input_best_rbboxes = ort_.KernelContext_GetInput(context, 1);
const float *best_rbboxes = reinterpret_cast<const float *>(
ort_.GetTensorData<const float *>(input_best_rbboxes));
// Setup output
OrtTensorDimensions out_dimensions(ort_, input_features);
int batch_size = out_dimensions.data()[0];
int input_channels = out_dimensions.data()[1];
int input_height = out_dimensions.data()[2];
int input_width = out_dimensions.data()[3];
OrtValue *output = ort_.KernelContext_GetOutput(
context, 0, out_dimensions.data(), out_dimensions.size());
float *out = ort_.GetTensorMutableData<float>(output);
OrtTensorTypeAndShapeInfo *output_info = ort_.GetTensorTypeAndShape(output);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
// TODO: forward here
int output_size = out_dimensions.data()[0];
for (auto i = 1; i < out_dimensions.size(); ++i) {
output_size *= out_dimensions.data()[i];
}
rotated_feature_align_forward_cpu_kernel<float>(
output_size, points_, features_data, best_rbboxes, spatial_scale_,
input_channels, input_height, input_width, out);
}
#ifndef ONNXRUNTIME_ROTATED_FEATURE_ALIGN_H
#define ONNXRUNTIME_ROTATED_FEATURE_ALIGN_H
#include <onnxruntime_cxx_api.h>
#include <cmath>
struct MMCVRotatedFeatureAlignKernel {
public:
MMCVRotatedFeatureAlignKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
: ort_(ort) {
spatial_scale_ = ort_.KernelInfoGetAttribute<float>(info, "spatial_scale");
points_ = ort_.KernelInfoGetAttribute<int64_t>(info, "points");
}
void Compute(OrtKernelContext* context);
private:
Ort::CustomOpApi ort_;
float spatial_scale_;
int points_;
};
struct MMCVRotatedFeatureAlignCustomOp
: Ort::CustomOpBase<MMCVRotatedFeatureAlignCustomOp,
MMCVRotatedFeatureAlignKernel> {
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
return new MMCVRotatedFeatureAlignKernel(api, info);
}
const char* GetName() const { return "MMCVRotatedFeatureAlign"; }
size_t GetInputTypeCount() const { return 2; }
ONNXTensorElementDataType GetInputType(size_t) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
size_t GetOutputTypeCount() const { return 1; }
ONNXTensorElementDataType GetOutputType(size_t) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
// force cpu
const char* GetExecutionProviderType() const {
return "CPUExecutionProvider";
}
};
#endif // ONNXRUNTIME_ROTATED_FEATURE_ALIGN_H
// modified from
// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
template <typename T>
T bilinear_interpolate(const T* input, const int height, const int width, T y,
T x, const int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
if (y <= 0) y = 0;
if (x <= 0) x = 0;
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
// do bilinear interpolation
T v1 = input[y_low * width + x_low];
T v2 = input[y_low * width + x_high];
T v3 = input[y_high * width + x_low];
T v4 = input[y_high * width + x_high];
const T v_low = fma(v2 - v1, lx, v1);
const T v_high = fma(v4 - v3, lx, v3);
const T val = fma(v_high - v_low, ly, v_low);
return val;
}
template <typename scalar_t>
void rotated_feature_align_forward_cpu_kernel(
const int nthreads, const int points, const scalar_t* bottom_data,
const scalar_t* best_bboxes, const scalar_t spatial_scale,
const int channels, const int height, const int width, scalar_t* top_data) {
for (int index = 0; index < nthreads; index++) {
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
const scalar_t* bbox_offset =
best_bboxes + ((n * height + h) * width + w) * 5;
scalar_t roi_y = bbox_offset[0] * spatial_scale;
scalar_t roi_x = bbox_offset[1] * spatial_scale;
scalar_t px[5] = {roi_x, 0, 0, 0, 0};
scalar_t py[5] = {roi_y, 0, 0, 0, 0};
if (points > 1) {
scalar_t roi_w = bbox_offset[2] * spatial_scale;
scalar_t roi_h = bbox_offset[3] * spatial_scale;
scalar_t roi_a = bbox_offset[4];
scalar_t w_2 = roi_w / 2, h_2 = roi_h / 2;
scalar_t cosa = cosf(roi_a), sina = sinf(roi_a);
scalar_t wx = cosa * w_2, wy = sina * w_2;
scalar_t hx = -sina * h_2, hy = cosa * h_2;
px[1] = roi_x + wx + hx;
py[1] = roi_y + wy + hy;
px[2] = roi_x - wx + hx;
py[2] = roi_y - wy + hy;
px[3] = roi_x - wx - hx;
py[3] = roi_y - wy - hy;
px[4] = roi_x + wx - hx;
py[4] = roi_y + wy - hy;
}
const scalar_t* offset_bottom_data =
bottom_data + (n * channels + c) * height * width;
scalar_t output_val = bottom_data[index];
for (int i = 0; i < points; i++) {
output_val += bilinear_interpolate<scalar_t>(offset_bottom_data, height,
width, py[i], px[i], i);
}
top_data[index] = output_val;
}
}
template <typename T>
void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
T& w1, T& w2, T& w3, T& w4, int& x_low,
int& x_high, int& y_low, int& y_high,
const int index) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0) y = 0;
if (x <= 0) x = 0;
y_low = (int)y;
x_low = (int)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
return;
}
template <typename scalar_t>
inline void valueAdd(scalar_t* address, scalar_t val) {
scalar_t old = *address;
*address = (old + val);
}
template <typename scalar_t>
void rotated_feature_align_backward_cpu_kernel(
const int nthreads, const int points, const scalar_t* top_diff,
const scalar_t* best_bboxes, const scalar_t spatial_scale,
const int channels, const int height, const int width,
scalar_t* bottom_diff) {
for (int index = 0; index < nthreads; index++) {
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
const scalar_t* bbox_offset =
best_bboxes + ((n * height + h) * width + w) * 5;
scalar_t roi_y = bbox_offset[0] * spatial_scale;
scalar_t roi_x = bbox_offset[1] * spatial_scale;
scalar_t px[5] = {roi_x, 0, 0, 0, 0};
scalar_t py[5] = {roi_y, 0, 0, 0, 0};
if (points > 1) {
scalar_t roi_w = bbox_offset[2] * spatial_scale;
scalar_t roi_h = bbox_offset[3] * spatial_scale;
scalar_t roi_a = bbox_offset[4];
scalar_t w_2 = roi_w / 2, h_2 = roi_h / 2;
scalar_t cosa = cosf(roi_a), sina = sinf(roi_a);
scalar_t wx = cosa * w_2, wy = sina * w_2;
scalar_t hx = -sina * h_2, hy = cosa * h_2;
px[1] = roi_x + wx + hx;
py[1] = roi_y + wy + hy;
px[2] = roi_x - wx + hx;
py[2] = roi_y - wy + hy;
px[3] = roi_x - wx - hx;
py[3] = roi_y - wy - hy;
px[4] = roi_x + wx - hx;
py[4] = roi_y + wy - hy;
}
scalar_t* offset_bottom_diff =
bottom_diff + (n * channels + c) * height * width;
scalar_t value_top_diff = top_diff[index];
valueAdd(bottom_diff + index, value_top_diff);
for (int i = 0; i < points; i++) {
scalar_t w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient<scalar_t>(height, width, py[i], px[i], w1,
w2, w3, w4, x_low, x_high, y_low,
y_high, i);
scalar_t g1 = value_top_diff * w1;
scalar_t g2 = value_top_diff * w2;
scalar_t g3 = value_top_diff * w3;
scalar_t g4 = value_top_diff * w4;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
valueAdd(offset_bottom_diff + y_low * width + x_low, g1);
valueAdd(offset_bottom_diff + y_low * width + x_high, g2);
valueAdd(offset_bottom_diff + y_high * width + x_low, g3);
valueAdd(offset_bottom_diff + y_high * width + x_high, g4);
}
}
}
}
void rotated_feature_align_forward_cpu(const Tensor features,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor output) {
const int output_size = features.numel();
AT_DISPATCH_FLOATING_TYPES(
features.scalar_type(), "rotated_feature_align_forward_cpu_kernel", [&] {
const scalar_t* bottom_data = features.data_ptr<scalar_t>();
const scalar_t* bboxes_data = best_bboxes.data_ptr<scalar_t>();
scalar_t* top_data = output.data_ptr<scalar_t>();
rotated_feature_align_forward_cpu_kernel<scalar_t>(
output_size, points, bottom_data, bboxes_data,
scalar_t(spatial_scale), features.size(1), features.size(2),
features.size(3), top_data);
});
}
void rotated_feature_align_backward_cpu(const Tensor top_grad,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor bottom_grad) {
const int output_size = top_grad.numel();
AT_DISPATCH_FLOATING_TYPES(
top_grad.scalar_type(), "rotated_feature_align_backward_cpu_kernel", [&] {
const scalar_t* top_diff = top_grad.data_ptr<scalar_t>();
const scalar_t* bboxes_data = best_bboxes.data_ptr<scalar_t>();
scalar_t* bottom_diff = bottom_grad.data_ptr<scalar_t>();
rotated_feature_align_backward_cpu_kernel<scalar_t>(
output_size, points, top_diff, bboxes_data, scalar_t(spatial_scale),
top_grad.size(1), top_grad.size(2), top_grad.size(3), bottom_diff);
});
}
void rotated_feature_align_forward_impl(const Tensor features,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor output);
void rotated_feature_align_backward_impl(const Tensor top_grad,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor bottom_grad);
REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, CPU,
rotated_feature_align_forward_cpu);
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, CPU,
rotated_feature_align_backward_cpu);
......@@ -20,6 +20,16 @@ class RotatedFeatureAlignFunction(Function):
Object <https://arxiv.org/abs/1908.05612>`_.
"""
@staticmethod
def symbolic(g, features, best_rbboxes, spatial_scale, points):
assert points in [1, 5]
return g.op(
'mmcv::MMCVRotatedFeatureAlign',
features,
best_rbboxes,
spatial_scale_f=spatial_scale,
points_i=points)
@staticmethod
def forward(ctx, features, best_rbboxes, spatial_scale, points):
"""
......
......@@ -485,6 +485,121 @@ def test_interpolate():
assert np.allclose(pytorch_result, onnx_result, atol=1e-3)
def test_rotated_feature_align():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
try:
from mmcv.ops import get_onnxruntime_op_path, rotated_feature_align
except (ImportError, ModuleNotFoundError):
pytest.skip('rotated_feature_align op is not successfully compiled')
ort_custom_op_path = get_onnxruntime_op_path()
if not os.path.exists(ort_custom_op_path):
pytest.skip('custom ops for onnxruntime are not compiled.')
spatial_scale = 1.0 / 8
points = 1
def warpped_function(feature, bbox):
return rotated_feature_align(
feature, bbox, spatial_scale=spatial_scale, points=points)
feature = torch.tensor([[[[1.2924, -0.2172, -0.5222, 0.1172],
[0.9144, 1.2248, 1.3115, -0.9690],
[-0.8949, -1.1797, -0.9093, -0.3961],
[-0.4586, 0.5062, -0.7947, -0.7397]],
[[-1.0943, -0.7495, 1.3461, -1.1652],
[0.2034, 0.6763, -1.2357, 0.5231],
[-1.0062, 1.2592, 1.4225, -0.3951],
[-0.1242, -1.6240, 0.1932, 2.7181]],
[[-1.6271, -1.0276, 0.0578, -0.2997],
[-0.9684, -1.6946, -1.3188, -1.1938],
[-1.6744, -0.8917, -0.6556, 1.0073],
[-0.1205, 0.3671, -0.3731, -0.5347]]],
[[[0.7035, 0.2089, -0.1774, 3.4670],
[-0.8505, -0.9278, 1.4714, 0.1644],
[0.0898, 0.3531, -0.4007, 0.1927],
[1.2569, -0.2636, -0.5223, 0.0616]],
[[0.1760, -0.7639, -0.4600, -1.3260],
[-0.9921, -0.2970, -0.8955, 1.0508],
[1.3515, -0.1641, 1.9679, 1.1986],
[-0.3616, 0.6287, 0.4933, 0.3360]],
[[-0.5860, 0.2124, -0.8700, 2.4200],
[-0.0551, -1.5103, -1.6779, 0.8399],
[0.8431, 1.2414, -1.1243, -0.3887],
[-2.1254, 0.6047, -0.3515, 0.7254]]]])
bbox = torch.tensor(
[[[[1.3080e+01, 1.2688e+01, 1.1214e+01, 9.3944e+01, -9.1905e-01],
[3.8104e+01, 1.0134e+01, 1.4659e+02, 9.0306e+01, -9.8211e-01],
[-5.3213e+01, 4.9508e+01, 5.1513e+01, 3.2055e+01, -3.1954e-01],
[2.6974e+01, 2.5248e+01, 5.4495e+01, 3.1083e+00, -6.2127e-01]],
[[-1.5604e+01, -5.1908e+01, 2.3998e+02, 1.5008e+01, -1.2546e+00],
[3.1354e+01, -7.3635e+00, 6.7879e+01, 3.5081e+01, -3.3851e-01],
[-5.3292e+00, 9.1946e+00, 1.2834e+01, 1.0485e+01, -1.3039e+00],
[-2.3925e+01, 3.6623e+01, 3.9875e+01, 7.2009e+01, -6.5934e-01]],
[[7.2114e+01, -2.3781e+01, 2.9106e+01, 8.4501e+01, -1.1340e+00],
[2.6258e+01, -7.7034e+00, 1.7629e+02, 1.0615e+02, -1.2156e+00],
[3.8057e+01, 4.6016e+01, 1.2965e+01, 6.9384e+00, -1.0855e+00],
[2.4428e+01, -1.6189e+01, 2.0572e+02, 3.1622e+01, -1.5719e-01]],
[[3.8226e+00, 2.9608e+01, 1.4457e+01, 6.8179e+01, -9.1997e-01],
[2.5003e+01, -4.2490e+01, 9.6007e+01, 4.9086e+01, -1.4786e+00],
[8.5983e+01, 5.4980e+01, 7.8080e+01, 1.0003e+02, -1.0926e+00],
[9.9065e+00, 4.1457e+01, 5.9799e+00, 1.7973e+01, -5.6313e-01]]],
[[[-1.8244e+01, 4.6309e+00, 5.3010e+01, 2.4310e+01, -7.0345e-01],
[1.9419e+01, 3.6704e+01, 5.2390e+01, 5.4133e+01, -3.7730e-01],
[5.6387e+01, 2.3752e+01, 9.0441e+00, 1.7792e+01, -1.5583e+00],
[3.6303e+01, 1.6396e+01, 2.0283e+01, 1.9148e+01, -8.3419e-01]],
[[3.2169e+01, 3.0521e+01, 2.6283e+01, 1.9680e+02, -3.0454e-01],
[2.5788e+01, -3.2189e+01, 8.8882e+01, 1.0207e+02, -1.5328e+00],
[8.4676e+00, -1.6668e+01, 2.4657e+01, 1.1275e+02, -4.0388e-01],
[-1.0799e+01, 6.0422e+00, 9.5807e+00, 3.3677e+01, -3.5438e-01]],
[[6.9363e+01, 1.0850e+01, 2.5968e+01, 2.2311e+01, -1.6408e-01],
[2.8140e+00, 4.6843e+00, 3.1289e+00, 2.1480e+01, -6.7583e-01],
[2.6661e+01, 4.5290e+01, 6.1679e+00, 3.0005e+01, -8.9806e-01],
[5.0871e+00, 1.3234e+01, 9.2087e+01, 4.9622e+01, -2.8020e-01]],
[[-1.2643e+01, 2.5176e+01, 5.0488e+01, 5.4246e+01, -4.4840e-01],
[-3.4521e+01, 9.8435e-01, 5.2413e+01, 9.7996e+00, -8.4218e-01],
[4.9829e+01, -1.0808e+01, 2.9848e+01, 7.3579e+01, -6.2672e-01],
[8.0446e+01, 2.8064e+01, 4.5273e+01, 5.3809e+01, -1.2359e+00]]]])
# compute pytorch_output
with torch.no_grad():
pytorch_output = rotated_feature_align(
feature, bbox, spatial_scale=spatial_scale, points=points)
# export and load onnx model
wrapped_model = WrapFunction(warpped_function)
with torch.no_grad():
torch.onnx.export(
wrapped_model, (feature, bbox),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['feature', 'bbox'],
opset_version=11)
onnx_model = onnx.load(onnx_file)
session_options = rt.SessionOptions()
if os.path.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
# compute onnx_output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
onnx_output = sess.run(None, {
'feature': feature.detach().numpy(),
'bbox': bbox.detach().numpy()
})
onnx_output = onnx_output[0]
# allclose
assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right'])
def test_corner_pool(mode, opset=11):
if torch.__version__ == 'parrots':
......
......@@ -7,7 +7,8 @@ from mmcv.ops import rotated_feature_align
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_rotated_feature_align():
@pytest.mark.parametrize('device', ['cuda', 'cpu'])
def test_rotated_feature_align(device):
feature = torch.tensor([[[[1.2924, -0.2172, -0.5222, 0.1172],
[0.9144, 1.2248, 1.3115, -0.9690],
[-0.8949, -1.1797, -0.9093, -0.3961],
......@@ -32,7 +33,7 @@ def test_rotated_feature_align():
[-0.0551, -1.5103, -1.6779, 0.8399],
[0.8431, 1.2414, -1.1243, -0.3887],
[-2.1254, 0.6047, -0.3515, 0.7254]]]],
device='cuda',
device=device,
requires_grad=True)
bbox = torch.tensor(
......@@ -68,7 +69,7 @@ def test_rotated_feature_align():
[-3.4521e+01, 9.8435e-01, 5.2413e+01, 9.7996e+00, -8.4218e-01],
[4.9829e+01, -1.0808e+01, 2.9848e+01, 7.3579e+01, -6.2672e-01],
[8.0446e+01, 2.8064e+01, 4.5273e+01, 5.3809e+01, -1.2359e+00]]]],
device='cuda',
device=device,
requires_grad=True)
expected_output = torch.tensor([[[[1.1095, -0.2172, -0.5222, -0.6225],
......@@ -94,34 +95,24 @@ def test_rotated_feature_align():
[[-0.5860, 0.2124, -0.8700, 2.4200],
[-0.0551, -1.5103, -1.6779, 0.8399],
[0.8431, 0.8455, -1.1243, -1.5994],
[-2.1254, 0.6047, -0.3515,
0.7254]]]]).cuda()
[-2.1254, 0.6047, -0.3515, 0.7254]]]],
device=device)
expected_grad = torch.tensor([[[[1.0000, 1.8507, 1.1493, 1.5222],
[1.0000, 1.1511, 1.2139, 1.4778],
[1.0000, 1.2629, 1.3721, 1.0000],
[3.0000, 1.0000, 1.0000, 2.0000]],
[[1.0000, 1.8507, 1.1493, 1.5222],
[1.0000, 1.1511, 1.2139, 1.4778],
[1.0000, 1.2629, 1.3721, 1.0000],
[3.0000, 1.0000, 1.0000, 2.0000]],
[[1.0000, 1.8507, 1.1493, 1.5222],
[1.0000, 1.1511, 1.2139, 1.4778],
[1.0000, 1.2629, 1.3721, 1.0000],
[3.0000, 1.0000, 1.0000, 2.0000]]],
[[[1.2687, 1.5055, 1.2382, 1.0000],
[1.1458, 1.4258, 1.4160, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000]],
[[1.2687, 1.5055, 1.2382, 1.0000],
[1.1458, 1.4258, 1.4160, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000]],
[[1.2687, 1.5055, 1.2382, 1.0000],
[1.1458, 1.4258, 1.4160, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000,
1.0000]]]]).cuda()
expected_grad = torch.tensor([
[[[1.0000, 1.8507, 1.1493, 1.5222], [1.0000, 1.1511, 1.2139, 1.4778],
[1.0000, 1.2629, 1.3721, 1.0000], [3.0000, 1.0000, 1.0000, 2.0000]],
[[1.0000, 1.8507, 1.1493, 1.5222], [1.0000, 1.1511, 1.2139, 1.4778],
[1.0000, 1.2629, 1.3721, 1.0000], [3.0000, 1.0000, 1.0000, 2.0000]],
[[1.0000, 1.8507, 1.1493, 1.5222], [1.0000, 1.1511, 1.2139, 1.4778],
[1.0000, 1.2629, 1.3721, 1.0000], [3.0000, 1.0000, 1.0000, 2.0000]]],
[[[1.2687, 1.5055, 1.2382, 1.0000], [1.1458, 1.4258, 1.4160, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000]],
[[1.2687, 1.5055, 1.2382, 1.0000], [1.1458, 1.4258, 1.4160, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000]],
[[1.2687, 1.5055, 1.2382, 1.0000], [1.1458, 1.4258, 1.4160, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000]]]
],
device=device)
output = rotated_feature_align(
feature, bbox, spatial_scale=1 / 8, points=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment