Commit 3f2a16d7 authored by Weiyun1025's avatar Weiyun1025 Committed by zhe chen
Browse files

Update tensorrt

parent d4dc2427
...@@ -60,6 +60,33 @@ class DCNv3Function(Function): ...@@ -60,6 +60,33 @@ class DCNv3Function(Function):
return grad_input, grad_offset, grad_mask, \ return grad_input, grad_offset, grad_mask, \
None, None, None, None, None, None, None, None, None, None, None, None None, None, None, None, None, None, None, None, None, None, None, None
@staticmethod
def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h,
stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, im2col_step):
"""Symbolic function for mmdeploy::DCNv3.
Returns:
DCNv3 op for onnx.
"""
return g.op(
'mmdeploy::TRTDCNv3',
input,
offset,
mask,
kernel_h_i=int(kernel_h),
kernel_w_i=int(kernel_w),
stride_h_i=int(stride_h),
stride_w_i=int(stride_w),
pad_h_i=int(pad_h),
pad_w_i=int(pad_w),
dilation_h_i=int(dilation_h),
dilation_w_i=int(dilation_w),
group_i=int(group),
group_channels_i=int(group_channels),
offset_scale_f=float(offset_scale),
im2col_step_i=int(im2col_step),
)
def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1): def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
_, H_, W_, _ = spatial_shapes _, H_, W_, _ = spatial_shapes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment