".github/git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "5ea40abf613e47bb56a0c06f695644d55671f585"
Unverified commit 1790e9f2, authored by Paige Wang, committed by GitHub

add modulated_deform_conv in onnxruntime support (#1281)

* add modulated_deform_conv in onnxruntime support

* Add docs descriptions

* Add gpu test in test_onnx.py

* code format

* remove new usage and move if outside for loop

* use memset when bias is nullptr
Parent commit: 77cb5786
@@ -45,6 +45,12 @@
  - [Inputs](#inputs-6)
  - [Outputs](#outputs-6)
  - [Type Constraints](#type-constraints-6)
- [MMCVModulatedDeformConv2d](#mmcvmodulateddeformconv2d)
  - [Description](#description-7)
  - [Parameters](#parameters-7)
  - [Inputs](#inputs-7)
  - [Outputs](#outputs-7)
  - [Type Constraints](#type-constraints-7)
<!-- TOC -->
@@ -283,3 +289,45 @@ Returns a tuple (`values`, `indices`) where `values` is the cumulative minimum e
### Type Constraints
- T:tensor(float32)
## MMCVModulatedDeformConv2d
### Description
Perform Modulated Deformable Convolution on the input feature map. Read [Deformable ConvNets v2: More Deformable, Better Results](https://arxiv.org/abs/1811.11168?from=timeline) for details.
### Parameters
| Type           | Parameter       | Description                                                                            |
| -------------- | --------------- | -------------------------------------------------------------------------------------- |
| `list of ints` | `stride`        | The stride of the convolving kernel. (sH, sW)                                           |
| `list of ints` | `padding`       | Paddings on both sides of the input. (padH, padW)                                       |
| `list of ints` | `dilation`      | The spacing between kernel elements. (dH, dW)                                           |
| `int`          | `deform_groups` | Number of groups used for the deformable offset.                                        |
| `int`          | `groups`        | Split input into groups. `input_channel` should be divisible by the number of groups.   |
### Inputs
<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>Input feature; 4-D tensor of shape (N, C, inH, inW), where N is the batch size, C is the number of channels, inH and inW are the height and width of the data.</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>Input offset; 4-D tensor of shape (N, deform_groups * 2 * kH * kW, outH, outW), where kH and kW are the height and width of the weight, and outH and outW are the height and width of the offset and output.</dd>
<dt><tt>inputs[2]</tt>: T</dt>
<dd>Input mask; 4-D tensor of shape (N, deform_groups * kH * kW, outH, outW), where kH and kW are the height and width of the weight, and outH and outW are the height and width of the offset and output.</dd>
<dt><tt>inputs[3]</tt>: T</dt>
<dd>Input weight; 4-D tensor of shape (output_channel, input_channel, kH, kW).</dd>
<dt><tt>inputs[4]</tt>: T, optional</dt>
<dd>Input bias; 1-D tensor of shape (output_channel).</dd>
</dl>
### Outputs
<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Output feature; 4-D tensor of shape (N, output_channel, outH, outW).</dd>
</dl>
### Type Constraints
- T:tensor(float32, Linear)
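To make the shape contract concrete, here is a short sketch (illustrative values only; it uses the standard convolution output-size formula, which the ONNX Runtime kernel below also applies) that computes the expected offset, mask, and output shapes:

```python
import math

# Hypothetical configuration matching the parameter table above.
N, C, inH, inW = 1, 3, 28, 28      # inputs[0]
output_channel, kH, kW = 64, 3, 3  # inputs[3] is (64, 3, 3, 3)
sH = sW = 1
padH = padW = 0
dH = dW = 1
deform_groups = 1

# Standard convolution output size.
outH = math.floor((inH + 2 * padH - dH * (kH - 1) - 1) / sH + 1)
outW = math.floor((inW + 2 * padW - dW * (kW - 1) - 1) / sW + 1)

print((N, deform_groups * 2 * kH * kW, outH, outW))  # inputs[1] offset: (1, 18, 26, 26)
print((N, deform_groups * kH * kW, outH, outW))      # inputs[2] mask:   (1, 9, 26, 26)
print((N, output_channel, outH, outW))               # outputs[0]:       (1, 64, 26, 26)
```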
@@ -22,7 +22,7 @@ struct MMCVCornerPoolKernel {
 struct MMCVCornerPoolCustomOp
     : Ort::CustomOpBase<MMCVCornerPoolCustomOp, MMCVCornerPoolKernel> {
-  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
+  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
     return new MMCVCornerPoolKernel(api, info);
   }
...
// Copyright (c) OpenMMLab. All rights reserved
#include "modulated_deform_conv.h"
#include <cmath>
#include <cstring>  // memset
#include <vector>
#include "../ort_mmcv_utils.h"
float bilinear_interpolate_2d(const float *src, const int64_t src_h,
const int64_t src_w, const float h,
const float w) {
if (h <= -1 || src_h <= h || w <= -1 || src_w <= w) {
return 0;
}
int64_t h_low = floor(h);
int64_t w_low = floor(w);
int64_t h_high = h_low + 1;
int64_t w_high = w_low + 1;
float lh = h - h_low;
float lw = w - w_low;
float hh = 1 - lh;
float hw = 1 - lw;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = src[h_low * src_w + w_low];
float v2 = 0;
if (h_low >= 0 && w_high <= src_w - 1) v2 = src[h_low * src_w + w_high];
float v3 = 0;
if (h_high <= src_h - 1 && w_low >= 0) v3 = src[h_high * src_w + w_low];
float v4 = 0;
if (h_high <= src_h - 1 && w_high <= src_w - 1)
v4 = src[h_high * src_w + w_high];
float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
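For cross-checking the sampling logic, here is a NumPy transcription of this sampler (an illustrative sketch, not part of the patch; `src` is a 2-D array):

```python
import numpy as np

def bilinear_interpolate_2d(src, h, w):
    """Bilinearly sample src at fractional position (h, w); taps that fall
    outside the image contribute zero, mirroring the C++ guards."""
    src_h, src_w = src.shape
    if h <= -1 or h >= src_h or w <= -1 or w >= src_w:
        return 0.0
    h_low, w_low = int(np.floor(h)), int(np.floor(w))
    lh, lw = h - h_low, w - w_low

    def at(y, x):  # zero-padded read
        return src[y, x] if 0 <= y < src_h and 0 <= x < src_w else 0.0

    return ((1 - lh) * (1 - lw) * at(h_low, w_low) +
            (1 - lh) * lw * at(h_low, w_low + 1) +
            lh * (1 - lw) * at(h_low + 1, w_low) +
            lh * lw * at(h_low + 1, w_low + 1))
```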
// output: (channels * kernel_h * kernel_w, dst_h * dst_w)
void deformable_im2col_2d(const float *input, const float *offset,
const float *mask, const int64_t src_h,
const int64_t src_w, const int64_t kernel_h,
const int64_t kernel_w, const int64_t pad_h,
const int64_t pad_w, const int64_t stride_h,
const int64_t stride_w, const int64_t dilation_h,
const int64_t dilation_w, const int64_t channels,
const int64_t offset_groups, const int64_t dst_h,
const int64_t dst_w, const bool use_mask,
float *columns) {
const int64_t workload = channels * dst_h * dst_w;
for (int64_t index = 0; index != workload; ++index) {
const int64_t ow = index % dst_w;
const int64_t oh = (index / dst_w) % dst_h;
const int64_t ic = index / (dst_w * dst_h);
const int64_t oc = ic * kernel_h * kernel_w;
int64_t c_per_offset_grp = channels / offset_groups;
const int64_t grp_idx = ic / c_per_offset_grp;
auto columns_ptr = columns + (oc * (dst_h * dst_w) + oh * dst_w + ow);
auto input_ptr = input + ic * (src_h * src_w);
auto offset_ptr =
offset + grp_idx * 2 * kernel_h * kernel_w * dst_h * dst_w;
auto mask_ptr = mask;
if (use_mask) {
mask_ptr += grp_idx * kernel_h * kernel_w * dst_h * dst_w;
}
for (int64_t kh = 0; kh < kernel_h; ++kh) {
for (int64_t kw = 0; kw < kernel_w; ++kw) {
const int64_t mask_idx = kh * kernel_w + kw;
const int64_t offset_idx = 2 * mask_idx;
float mask_value = 1;
if (use_mask) {
mask_value = mask_ptr[mask_idx * (dst_h * dst_w) + oh * dst_w + ow];
}
const float offset_h =
offset_ptr[offset_idx * (dst_h * dst_w) + oh * dst_w + ow];
const float offset_w =
offset_ptr[(offset_idx + 1) * (dst_h * dst_w) + oh * dst_w + ow];
const float ih = (oh * stride_h - pad_h) + kh * dilation_h + offset_h;
const float iw = (ow * stride_w - pad_w) + kw * dilation_w + offset_w;
*columns_ptr = mask_value *
bilinear_interpolate_2d(input_ptr, src_h, src_w, ih, iw);
columns_ptr += dst_h * dst_w;
}
}
}
}
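The `columns` buffer produced here has shape `(channels * kernel_h * kernel_w, dst_h * dst_w)`: for each output pixel, every input channel contributes its kH x kW deformed, mask-weighted taps, which turns the convolution into a single GEMM per group. A compact NumPy sketch of the same loop (illustrative only; assumes square stride/padding/dilation for brevity and reuses the `bilinear_interpolate_2d` sketch above):

```python
def deformable_im2col_2d(inp, offset, mask, kH, kW, pad, stride, dil,
                         offset_groups, dst_h, dst_w):
    """inp: (C, H, W); offset: (offset_groups*2*kH*kW, dst_h, dst_w);
    mask: (offset_groups*kH*kW, dst_h, dst_w) or None."""
    C = inp.shape[0]
    cols = np.zeros((C * kH * kW, dst_h * dst_w), dtype=np.float32)
    c_per_grp = C // offset_groups
    for ic in range(C):
        g = ic // c_per_grp
        off = offset[g * 2 * kH * kW:(g + 1) * 2 * kH * kW]
        msk = None if mask is None else mask[g * kH * kW:(g + 1) * kH * kW]
        for oh in range(dst_h):
            for ow in range(dst_w):
                for k in range(kH * kW):
                    kh, kw = divmod(k, kW)
                    ih = oh * stride - pad + kh * dil + off[2 * k, oh, ow]
                    iw = ow * stride - pad + kw * dil + off[2 * k + 1, oh, ow]
                    m = 1.0 if msk is None else msk[k, oh, ow]
                    cols[ic * kH * kW + k, oh * dst_w + ow] = \
                        m * bilinear_interpolate_2d(inp[ic], ih, iw)
    return cols
```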
void gemm_ref_fp32(const float *A, const float *B, const float *V,
const float *H, const int32_t trans_A, const int32_t trans_B,
const int32_t M, const int32_t N, const int32_t K,
const float alpha, const float beta, float *Y) {
if (!trans_A && !trans_B) { // MK, KN; NN
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[m * K + k] * B[k * N + n];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
if (trans_A && !trans_B) { // KM, KN; TN
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[k * M + m] * B[k * N + n];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
if (trans_A && trans_B) { // KM, NK; TT
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[k * M + m] * B[n * K + k];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
if (!trans_A && trans_B) { // MK, NK; NT
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[m * K + k] * B[n * K + k];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
}
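`gemm_ref_fp32` is a reference GEMM computing `Y = alpha * op(A) @ op(B) + beta * V + beta * H`, where `V` (if given) is a per-column bias vector and `H` (if given) is a full matrix addend; the conv reference below passes `H = dst_ptr` with `beta = 1` so the GEMM accumulates onto the bias-initialized output. A NumPy equivalent for sanity-checking (illustrative):

```python
def gemm_ref_fp32(A, B, V, H, trans_A, trans_B, alpha, beta):
    Y = alpha * ((A.T if trans_A else A) @ (B.T if trans_B else B))
    if V is not None:
        Y = Y + beta * V  # (N,) broadcasts as a per-column bias
    if H is not None:
        Y = Y + beta * H  # full (M, N) addend
    return Y
```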
void deformable_conv2d_ref_fp32(
const float *src, const float *offset, const float *mask,
const float *filter, const float *bias, const int64_t batch,
const int64_t src_c, const int64_t src_h, const int64_t src_w,
const int64_t dst_c, const int64_t dst_h, const int64_t dst_w,
const int64_t group, const int64_t offset_group, const int64_t channels,
const int64_t num_output, const int64_t kernel_h, const int64_t kernel_w,
const int64_t stride_h, const int64_t stride_w, const int64_t pad_h,
const int64_t pad_w, const int64_t dilation_h, const int64_t dilation_w,
float *columns, float *dst) {
const int64_t ic_per_gp = channels / group;
const int64_t oc_per_gp = num_output / group;
for (int64_t b = 0; b < batch; ++b) {
for (int64_t g = 0; g < group; ++g) {
deformable_im2col_2d(
src + b * src_c * src_h * src_w + g * ic_per_gp * src_h * src_w,
offset + b * offset_group * 2 * kernel_h * kernel_w * dst_h * dst_w,
mask + b * offset_group * kernel_h * kernel_w * dst_h * dst_w, src_h,
src_w, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, ic_per_gp, offset_group, dst_h, dst_w,
mask != nullptr, columns);
float *dst_ptr =
dst + b * dst_c * dst_h * dst_w + g * oc_per_gp * dst_h * dst_w;
if (bias != nullptr) {
const float *bias_ptr = bias + g * oc_per_gp;
for (int64_t oc = 0; oc < oc_per_gp; ++oc) {
for (int64_t hw = 0; hw < dst_h * dst_w; ++hw) {
dst_ptr[oc * dst_h * dst_w + hw] = bias_ptr[oc];
}
}
} else {
memset(dst_ptr, 0, sizeof(float) * oc_per_gp * dst_h * dst_w);
}
gemm_ref_fp32(filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w,
columns, nullptr, dst_ptr, 0, 0, oc_per_gp, dst_h * dst_w,
ic_per_gp * kernel_h * kernel_w, 1.0f, 1.0f, dst_ptr);
}
}
}
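Per batch and group, the output block is first filled with the bias (or zeroed via `memset`), and the GEMM then accumulates the convolution onto it (`H = dst_ptr`, `beta = 1`). In NumPy terms this is simply (illustrative, for one `(b, g)` pair; `filter_g`, `bias_g`, `columns` as above):

```python
W_g = filter_g.reshape(oc_per_gp, ic_per_gp * kH * kW)  # (oc, ic*kH*kW)
dst_g = W_g @ columns                                   # (oc, dst_h*dst_w)
if bias_g is not None:
    dst_g += bias_g[:, None]                            # broadcast over pixels
```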
MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(
OrtApi api, const OrtKernelInfo *info)
: api_(api), ort_(api_), info_(info) {
std::vector<int64_t> stride =
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
stride_height_ = stride[0];
stride_width_ = stride[1];
std::vector<int64_t> padding =
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "padding");
padding_height_ = padding[0];
padding_width_ = padding[1];
std::vector<int64_t> dilation =
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "dilation");
dilation_height_ = dilation[0];
dilation_width_ = dilation[1];
deformable_group_ =
ort_.KernelInfoGetAttribute<int64_t>(info, "deform_groups");
group_ = ort_.KernelInfoGetAttribute<int64_t>(info, "groups");
// create allocator
allocator_ = Ort::AllocatorWithDefaultOptions();
}
void MMCVModulatedDeformConvKernel::Compute(OrtKernelContext *context) {
const int64_t stride_height = stride_height_;
const int64_t stride_width = stride_width_;
const int64_t padding_height = padding_height_;
const int64_t padding_width = padding_width_;
const int64_t dilation_height = dilation_height_;
const int64_t dilation_width = dilation_width_;
const int64_t deformable_group = deformable_group_;
const int64_t group = group_;
const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
const float *input_data =
reinterpret_cast<const float *>(ort_.GetTensorData<float>(input));
const OrtValue *offset = ort_.KernelContext_GetInput(context, 1);
const float *offset_data =
reinterpret_cast<const float *>(ort_.GetTensorData<float>(offset));
const OrtValue *mask = ort_.KernelContext_GetInput(context, 2);
const float *mask_data =
reinterpret_cast<const float *>(ort_.GetTensorData<float>(mask));
const OrtValue *filter = ort_.KernelContext_GetInput(context, 3);
const float *filter_data =
reinterpret_cast<const float *>(ort_.GetTensorData<float>(filter));
const OrtValue *bias = ort_.KernelContext_GetInput(context, 4);
const float *bias_data =
(bias != nullptr)
? reinterpret_cast<const float *>(ort_.GetTensorData<float>(bias))
: nullptr;
OrtTensorDimensions input_dims(ort_, input);
OrtTensorDimensions filter_dims(ort_, filter);
int64_t batch = input_dims[0];
int64_t channels = input_dims[1];
int64_t in_height = input_dims[2];
int64_t in_width = input_dims[3];
int64_t num_output = filter_dims[0];
int64_t kernel_height = filter_dims[2];
int64_t kernel_width = filter_dims[3];
// get output memory
int64_t out_height = floor((in_height + 2 * padding_height -
dilation_height * (kernel_height - 1) - 1) /
stride_height +
1);
int64_t out_width = floor(
(in_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) /
stride_width +
1);
std::vector<int64_t> output_dims = {batch, num_output, out_height, out_width};
OrtValue *output = ort_.KernelContext_GetOutput(
context, 0, output_dims.data(), output_dims.size());
float *out_ptr = ort_.GetTensorMutableData<float>(output);
// allocate tmp memory
int64_t column_len = (channels / group) * kernel_height * kernel_width *
out_height * out_width;
float *columns = (float *)allocator_.Alloc(sizeof(float) * column_len);
deformable_conv2d_ref_fp32(
input_data, offset_data, mask_data, filter_data, bias_data, batch,
channels, in_height, in_width, num_output, out_height, out_width, group,
deformable_group, channels, num_output, kernel_height, kernel_width,
stride_height, stride_width, padding_height, padding_width,
dilation_height, dilation_width, columns, out_ptr);
allocator_.Free(columns);  // release the temporary im2col buffer
}
@@ -3,6 +3,7 @@
 #include "corner_pool.h"
 #include "grid_sample.h"
+#include "modulated_deform_conv.h"
 #include "nms.h"
 #include "ort_mmcv_utils.h"
 #include "reduce_ops.h"
@@ -19,6 +20,7 @@ GridSampleOp c_GridSampleOp;
 MMCVCumMaxCustomOp c_MMCVCumMaxCustomOp;
 MMCVCumMinCustomOp c_MMCVCumMinCustomOp;
 MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp;
+MMCVModulatedDeformConvOp c_MMCVModulatedDeformConvOp;
 OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
                                           const OrtApiBase *api) {
@@ -64,5 +66,10 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
     return status;
   }
+  if (auto status =
+          ortApi->CustomOpDomain_Add(domain, &c_MMCVModulatedDeformConvOp)) {
+    return status;
+  }
   return ortApi->AddCustomOpDomain(options, domain);
 }
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ONNXRUNTIME_MODULATED_DEFORM_CONV_H
#define ONNXRUNTIME_MODULATED_DEFORM_CONV_H
#include <onnxruntime_cxx_api.h>
struct MMCVModulatedDeformConvKernel {
MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info);
void Compute(OrtKernelContext *context);
protected:
OrtApi api_;
Ort::CustomOpApi ort_;
const OrtKernelInfo *info_;
Ort::AllocatorWithDefaultOptions allocator_;
int64_t stride_height_;
int64_t stride_width_;
int64_t padding_height_;
int64_t padding_width_;
int64_t dilation_height_;
int64_t dilation_width_;
int64_t deformable_group_;
int64_t group_;
};
struct MMCVModulatedDeformConvOp
: Ort::CustomOpBase<MMCVModulatedDeformConvOp,
MMCVModulatedDeformConvKernel> {
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
return new MMCVModulatedDeformConvKernel(api, info);
}
const char *GetName() const { return "MMCVModulatedDeformConv2d"; };
size_t GetInputTypeCount() const { return 5; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(
size_t index) const {
// The last input (index == 4) is optional, which is bias
if (index == 4)
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
size_t GetOutputTypeCount() const { return 1; };
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};
// force cpu
const char *GetExecutionProviderType() const {
return "CPUExecutionProvider";
};
};
#endif
@@ -44,7 +44,7 @@ struct MMCVCumMinKernel {
 struct MMCVCumMaxCustomOp
     : Ort::CustomOpBase<MMCVCumMaxCustomOp, MMCVCumMaxKernel> {
-  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
+  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
     return new MMCVCumMaxKernel(api, info);
   }
@@ -69,7 +69,7 @@ struct MMCVCumMaxCustomOp
 struct MMCVCumMinCustomOp
     : Ort::CustomOpBase<MMCVCumMinCustomOp, MMCVCumMinKernel> {
-  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
+  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
     return new MMCVCumMinKernel(api, info);
   }
...
@@ -39,7 +39,7 @@ struct MMCVRoiAlignKernel {
 struct MMCVRoiAlignCustomOp
     : Ort::CustomOpBase<MMCVRoiAlignCustomOp, MMCVRoiAlignKernel> {
-  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
+  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
     return new MMCVRoiAlignKernel(api, info);
   }
   const char* GetName() const { return "MMCVRoiAlign"; }
...
@@ -39,7 +39,7 @@ struct MMCVRoIAlignRotatedKernel {
 struct MMCVRoIAlignRotatedCustomOp
     : Ort::CustomOpBase<MMCVRoIAlignRotatedCustomOp,
                         MMCVRoIAlignRotatedKernel> {
-  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
+  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
     return new MMCVRoIAlignRotatedKernel(api, info);
   }
   const char* GetName() const { return "MMCVRoIAlignRotated"; }
...
@@ -22,7 +22,7 @@ struct SoftNmsKernel {
 };
 struct SoftNmsOp : Ort::CustomOpBase<SoftNmsOp, SoftNmsKernel> {
-  void *CreateKernel(OrtApi api, const OrtKernelInfo *info) {
+  void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
     return new SoftNmsKernel(api, info);
   };
...
@@ -638,3 +638,95 @@ def test_roll(shifts_dims_pair):
     pytorch_output = wrapped_model(input.clone())
     torch.testing.assert_allclose(ort_output, pytorch_output)
@pytest.mark.skipif(
    torch.__version__ == 'parrots',
    reason='onnx is not supported in parrots directly')
@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason='modulated_deform_conv2d is only supported on GPU')
def test_modulated_deform_conv2d():
    try:
        from mmcv.ops import ModulatedDeformConv2d
        from mmcv.ops import get_onnxruntime_op_path
    except (ImportError, ModuleNotFoundError):
        pytest.skip('modulated_deform_conv op is not successfully compiled')

    ort_custom_op_path = get_onnxruntime_op_path()

    # modulated deform conv config
    in_channels = 3
    out_channels = 64
    stride = 1
    padding = 0
    dilation = 1
    groups = 1
    deform_groups = 1
    kernel_size = 3

    input = torch.rand(1, in_channels, 28, 28).cuda()  # (n, c, h, w)
    conv_offset = nn.Conv2d(
        in_channels=in_channels,
        out_channels=deform_groups * 3 * kernel_size * kernel_size,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True).cuda()
    out = conv_offset(input)
    o1, o2, mask = torch.chunk(out, 3, dim=1)
    offset = torch.cat((o1, o2), dim=1)
    mask = torch.sigmoid(mask)

    model_with_bias = ModulatedDeformConv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        deform_groups,
        bias=True)
    model_without_bias = ModulatedDeformConv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        deform_groups,
        bias=False)
    models = [model_with_bias.cuda(), model_without_bias.cuda()]

    for model in models:
        # export and load onnx model
        with torch.no_grad():
            torch.onnx.export(
                model, (input, offset, mask),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=['input', 'offset', 'mask'],
                opset_version=11)
        session_options = rt.SessionOptions()
        if os.path.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)

        # compute onnx_output
        sess = rt.InferenceSession(onnx_file, session_options)
        onnx_output = sess.run(
            None, {
                'input': input.cpu().detach().numpy(),
                'offset': offset.cpu().detach().numpy(),
                'mask': mask.cpu().detach().numpy()
            })[0]

        # compute pytorch_output
        with torch.no_grad():
            pytorch_output = model(input, offset, mask).cpu()

        # allclose
        assert np.allclose(pytorch_output, onnx_output, atol=1e-3)