Unverified Commit 8e3a8015 authored by BigBigDream's avatar BigBigDream Committed by GitHub
Browse files

fix mmcv ci for parrots (#782)

* fix mmcv ci for parrots

* fix mmcv ci

* fix lint
parent f169fb52
......@@ -4,7 +4,6 @@ from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from ..onnx import is_custom_op_loaded
from ..utils import deprecated_api_warning, ext_loader
ext_module = ext_loader.load_ext('_ext',
......@@ -16,6 +15,7 @@ class RoIAlignFunction(Function):
@staticmethod
def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
pool_mode, aligned):
from ..onnx import is_custom_op_loaded
has_custom_op = is_custom_op_loaded()
if has_custom_op:
return g.op(
......
......@@ -3,6 +3,7 @@ from collections import OrderedDict
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
......@@ -47,11 +48,19 @@ def assert_tensor_equal(tensor_a, tensor_b):
def test_get_state_dict():
state_dict_keys = set([
'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
'block.norm.bias', 'block.norm.running_mean', 'block.norm.running_var',
'block.norm.num_batches_tracked', 'conv.weight', 'conv.bias'
])
if torch.__version__ == 'parrots':
state_dict_keys = set([
'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
'block.norm.bias', 'block.norm.running_mean',
'block.norm.running_var', 'conv.weight', 'conv.bias'
])
else:
state_dict_keys = set([
'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
'block.norm.bias', 'block.norm.running_mean',
'block.norm.running_var', 'block.norm.num_batches_tracked',
'conv.weight', 'conv.bias'
])
model = Model()
state_dict = get_state_dict(model)
......@@ -68,8 +77,9 @@ def test_get_state_dict():
model.block.norm.running_mean)
assert_tensor_equal(state_dict['block.norm.running_var'],
model.block.norm.running_var)
assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
model.block.norm.num_batches_tracked)
if torch.__version__ != 'parrots':
assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
model.block.norm.num_batches_tracked)
assert_tensor_equal(state_dict['conv.weight'], model.conv.weight)
assert_tensor_equal(state_dict['conv.bias'], model.conv.bias)
......@@ -89,8 +99,10 @@ def test_get_state_dict():
wrapped_model.module.block.norm.running_mean)
assert_tensor_equal(state_dict['block.norm.running_var'],
wrapped_model.module.block.norm.running_var)
assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
wrapped_model.module.block.norm.num_batches_tracked)
if torch.__version__ != 'parrots':
assert_tensor_equal(
state_dict['block.norm.num_batches_tracked'],
wrapped_model.module.block.norm.num_batches_tracked)
assert_tensor_equal(state_dict['conv.weight'],
wrapped_model.module.conv.weight)
assert_tensor_equal(state_dict['conv.bias'],
......@@ -115,9 +127,10 @@ def test_get_state_dict():
wrapped_model.module.block.module.norm.running_mean)
assert_tensor_equal(state_dict['block.norm.running_var'],
wrapped_model.module.block.module.norm.running_var)
assert_tensor_equal(
state_dict['block.norm.num_batches_tracked'],
wrapped_model.module.block.module.norm.num_batches_tracked)
if torch.__version__ != 'parrots':
assert_tensor_equal(
state_dict['block.norm.num_batches_tracked'],
wrapped_model.module.block.module.norm.num_batches_tracked)
assert_tensor_equal(state_dict['conv.weight'],
wrapped_model.module.conv.module.weight)
assert_tensor_equal(state_dict['conv.bias'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment