Unverified Commit 8e3a8015 authored by BigBigDream's avatar BigBigDream Committed by GitHub
Browse files

fix mmcv ci for parrots (#782)

* fix mmcv ci for parrots

* fix mmcv ci

* fix lint
parent f169fb52
...@@ -4,7 +4,6 @@ from torch.autograd import Function ...@@ -4,7 +4,6 @@ from torch.autograd import Function
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from ..onnx import is_custom_op_loaded
from ..utils import deprecated_api_warning, ext_loader from ..utils import deprecated_api_warning, ext_loader
ext_module = ext_loader.load_ext('_ext', ext_module = ext_loader.load_ext('_ext',
...@@ -16,6 +15,7 @@ class RoIAlignFunction(Function): ...@@ -16,6 +15,7 @@ class RoIAlignFunction(Function):
@staticmethod @staticmethod
def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio, def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
pool_mode, aligned): pool_mode, aligned):
from ..onnx import is_custom_op_loaded
has_custom_op = is_custom_op_loaded() has_custom_op = is_custom_op_loaded()
if has_custom_op: if has_custom_op:
return g.op( return g.op(
......
...@@ -3,6 +3,7 @@ from collections import OrderedDict ...@@ -3,6 +3,7 @@ from collections import OrderedDict
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DataParallel from torch.nn.parallel import DataParallel
...@@ -47,11 +48,19 @@ def assert_tensor_equal(tensor_a, tensor_b): ...@@ -47,11 +48,19 @@ def assert_tensor_equal(tensor_a, tensor_b):
def test_get_state_dict(): def test_get_state_dict():
state_dict_keys = set([ if torch.__version__ == 'parrots':
'block.conv.weight', 'block.conv.bias', 'block.norm.weight', state_dict_keys = set([
'block.norm.bias', 'block.norm.running_mean', 'block.norm.running_var', 'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
'block.norm.num_batches_tracked', 'conv.weight', 'conv.bias' 'block.norm.bias', 'block.norm.running_mean',
]) 'block.norm.running_var', 'conv.weight', 'conv.bias'
])
else:
state_dict_keys = set([
'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
'block.norm.bias', 'block.norm.running_mean',
'block.norm.running_var', 'block.norm.num_batches_tracked',
'conv.weight', 'conv.bias'
])
model = Model() model = Model()
state_dict = get_state_dict(model) state_dict = get_state_dict(model)
...@@ -68,8 +77,9 @@ def test_get_state_dict(): ...@@ -68,8 +77,9 @@ def test_get_state_dict():
model.block.norm.running_mean) model.block.norm.running_mean)
assert_tensor_equal(state_dict['block.norm.running_var'], assert_tensor_equal(state_dict['block.norm.running_var'],
model.block.norm.running_var) model.block.norm.running_var)
assert_tensor_equal(state_dict['block.norm.num_batches_tracked'], if torch.__version__ != 'parrots':
model.block.norm.num_batches_tracked) assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
model.block.norm.num_batches_tracked)
assert_tensor_equal(state_dict['conv.weight'], model.conv.weight) assert_tensor_equal(state_dict['conv.weight'], model.conv.weight)
assert_tensor_equal(state_dict['conv.bias'], model.conv.bias) assert_tensor_equal(state_dict['conv.bias'], model.conv.bias)
...@@ -89,8 +99,10 @@ def test_get_state_dict(): ...@@ -89,8 +99,10 @@ def test_get_state_dict():
wrapped_model.module.block.norm.running_mean) wrapped_model.module.block.norm.running_mean)
assert_tensor_equal(state_dict['block.norm.running_var'], assert_tensor_equal(state_dict['block.norm.running_var'],
wrapped_model.module.block.norm.running_var) wrapped_model.module.block.norm.running_var)
assert_tensor_equal(state_dict['block.norm.num_batches_tracked'], if torch.__version__ != 'parrots':
wrapped_model.module.block.norm.num_batches_tracked) assert_tensor_equal(
state_dict['block.norm.num_batches_tracked'],
wrapped_model.module.block.norm.num_batches_tracked)
assert_tensor_equal(state_dict['conv.weight'], assert_tensor_equal(state_dict['conv.weight'],
wrapped_model.module.conv.weight) wrapped_model.module.conv.weight)
assert_tensor_equal(state_dict['conv.bias'], assert_tensor_equal(state_dict['conv.bias'],
...@@ -115,9 +127,10 @@ def test_get_state_dict(): ...@@ -115,9 +127,10 @@ def test_get_state_dict():
wrapped_model.module.block.module.norm.running_mean) wrapped_model.module.block.module.norm.running_mean)
assert_tensor_equal(state_dict['block.norm.running_var'], assert_tensor_equal(state_dict['block.norm.running_var'],
wrapped_model.module.block.module.norm.running_var) wrapped_model.module.block.module.norm.running_var)
assert_tensor_equal( if torch.__version__ != 'parrots':
state_dict['block.norm.num_batches_tracked'], assert_tensor_equal(
wrapped_model.module.block.module.norm.num_batches_tracked) state_dict['block.norm.num_batches_tracked'],
wrapped_model.module.block.module.norm.num_batches_tracked)
assert_tensor_equal(state_dict['conv.weight'], assert_tensor_equal(state_dict['conv.weight'],
wrapped_model.module.conv.module.weight) wrapped_model.module.conv.module.weight)
assert_tensor_equal(state_dict['conv.bias'], assert_tensor_equal(state_dict['conv.bias'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment