Unverified Commit a6d7dde7 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Fix] Fix torch2.0 dcn/mdcn symbolic (#2695)

* fix

* fix lint
parent 5a45fac9
......@@ -108,8 +108,10 @@ class DeformConv2dFunction(Function):
return output
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(
DeformConv2dFunction._output_size(ctx, input, weight))
output = input.new_empty([
int(i)
for i in DeformConv2dFunction._output_size(ctx, input, weight)
])
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
......
......@@ -136,8 +136,10 @@ class ModulatedDeformConv2dFunction(Function):
ctx, input, offset, mask, weight, bias)
return output
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
output = input.new_empty([
int(i) for i in ModulatedDeformConv2dFunction._output_size(
ctx, input, weight)
])
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
ext_module.modulated_deform_conv_forward(
input,
......
......@@ -3,7 +3,6 @@ import os
import numpy as np
import onnx
import onnxruntime as rt
import pytest
import torch
import torch.nn as nn
......@@ -37,6 +36,7 @@ class WrapFunction(nn.Module):
def test_roialign():
rt = pytest.importorskip('onnxruntime')
try:
from mmcv.ops import roi_align
except (ImportError, ModuleNotFoundError):
......@@ -106,6 +106,7 @@ def test_roialign():
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
def test_roipool():
rt = pytest.importorskip('onnxruntime')
from mmcv.ops import roi_pool
# roi pool config
......@@ -204,7 +205,7 @@ def test_deform_conv():
from mmcv.ops import DeformConv2dPack
x = torch.randn(1, 2, 4, 4, device='cuda')
_test_symbolic(
DeformConv2dPack(2, 4, 3, 1, 1).cuda(), x, 'MMCVDeformConv2d')
DeformConv2dPack(2, 4, 3, 1, 1).cuda(), (x, ), 'MMCVDeformConv2d')
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment