Unverified Commit a6d7dde7 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Fix] Fix torch2.0 dcn/mdcn symbolic (#2695)

* fix

* fix lint
parent 5a45fac9
...@@ -108,8 +108,10 @@ class DeformConv2dFunction(Function): ...@@ -108,8 +108,10 @@ class DeformConv2dFunction(Function):
return output return output
ctx.save_for_backward(input, offset, weight) ctx.save_for_backward(input, offset, weight)
output = input.new_empty( output = input.new_empty([
DeformConv2dFunction._output_size(ctx, input, weight)) int(i)
for i in DeformConv2dFunction._output_size(ctx, input, weight)
])
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
......
...@@ -136,8 +136,10 @@ class ModulatedDeformConv2dFunction(Function): ...@@ -136,8 +136,10 @@ class ModulatedDeformConv2dFunction(Function):
ctx, input, offset, mask, weight, bias) ctx, input, offset, mask, weight, bias)
return output return output
ctx.save_for_backward(input, offset, mask, weight, bias) ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty( output = input.new_empty([
ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) int(i) for i in ModulatedDeformConv2dFunction._output_size(
ctx, input, weight)
])
ctx._bufs = [input.new_empty(0), input.new_empty(0)] ctx._bufs = [input.new_empty(0), input.new_empty(0)]
ext_module.modulated_deform_conv_forward( ext_module.modulated_deform_conv_forward(
input, input,
......
...@@ -3,7 +3,6 @@ import os ...@@ -3,7 +3,6 @@ import os
import numpy as np import numpy as np
import onnx import onnx
import onnxruntime as rt
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -37,6 +36,7 @@ class WrapFunction(nn.Module): ...@@ -37,6 +36,7 @@ class WrapFunction(nn.Module):
def test_roialign(): def test_roialign():
rt = pytest.importorskip('onnxruntime')
try: try:
from mmcv.ops import roi_align from mmcv.ops import roi_align
except (ImportError, ModuleNotFoundError): except (ImportError, ModuleNotFoundError):
...@@ -106,6 +106,7 @@ def test_roialign(): ...@@ -106,6 +106,7 @@ def test_roialign():
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU') @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
def test_roipool(): def test_roipool():
rt = pytest.importorskip('onnxruntime')
from mmcv.ops import roi_pool from mmcv.ops import roi_pool
# roi pool config # roi pool config
...@@ -204,7 +205,7 @@ def test_deform_conv(): ...@@ -204,7 +205,7 @@ def test_deform_conv():
from mmcv.ops import DeformConv2dPack from mmcv.ops import DeformConv2dPack
x = torch.randn(1, 2, 4, 4, device='cuda') x = torch.randn(1, 2, 4, 4, device='cuda')
_test_symbolic( _test_symbolic(
DeformConv2dPack(2, 4, 3, 1, 1).cuda(), x, 'MMCVDeformConv2d') DeformConv2dPack(2, 4, 3, 1, 1).cuda(), (x, ), 'MMCVDeformConv2d')
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU') @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment