Unverified Commit 45fa3e44 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

Add pyupgrade pre-commit hook (#1937)

* add pyupgrade

* add options for pyupgrade

* minor refinement
parent c561264d
...@@ -30,7 +30,7 @@ class BasicBlock(nn.Module): ...@@ -30,7 +30,7 @@ class BasicBlock(nn.Module):
downsample=None, downsample=None,
style='pytorch', style='pytorch',
with_cp=False): with_cp=False):
super(BasicBlock, self).__init__() super().__init__()
assert style in ['pytorch', 'caffe'] assert style in ['pytorch', 'caffe']
self.conv1 = conv3x3(inplanes, planes, stride, dilation) self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.bn1 = nn.BatchNorm2d(planes) self.bn1 = nn.BatchNorm2d(planes)
...@@ -77,7 +77,7 @@ class Bottleneck(nn.Module): ...@@ -77,7 +77,7 @@ class Bottleneck(nn.Module):
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
it is "caffe", the stride-two layer is the first 1x1 conv layer. it is "caffe", the stride-two layer is the first 1x1 conv layer.
""" """
super(Bottleneck, self).__init__() super().__init__()
assert style in ['pytorch', 'caffe'] assert style in ['pytorch', 'caffe']
if style == 'pytorch': if style == 'pytorch':
conv1_stride = 1 conv1_stride = 1
...@@ -218,7 +218,7 @@ class ResNet(nn.Module): ...@@ -218,7 +218,7 @@ class ResNet(nn.Module):
bn_eval=True, bn_eval=True,
bn_frozen=False, bn_frozen=False,
with_cp=False): with_cp=False):
super(ResNet, self).__init__() super().__init__()
if depth not in self.arch_settings: if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet') raise KeyError(f'invalid depth {depth} for resnet')
assert num_stages >= 1 and num_stages <= 4 assert num_stages >= 1 and num_stages <= 4
...@@ -293,7 +293,7 @@ class ResNet(nn.Module): ...@@ -293,7 +293,7 @@ class ResNet(nn.Module):
return tuple(outs) return tuple(outs)
def train(self, mode=True): def train(self, mode=True):
super(ResNet, self).train(mode) super().train(mode)
if self.bn_eval: if self.bn_eval:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.BatchNorm2d): if isinstance(m, nn.BatchNorm2d):
......
...@@ -277,10 +277,10 @@ def print_model_with_flops(model, ...@@ -277,10 +277,10 @@ def print_model_with_flops(model,
return ', '.join([ return ', '.join([
params_to_string( params_to_string(
accumulated_num_params, units='M', precision=precision), accumulated_num_params, units='M', precision=precision),
'{:.3%} Params'.format(accumulated_num_params / total_params), f'{accumulated_num_params / total_params:.3%} Params',
flops_to_string( flops_to_string(
accumulated_flops_cost, units=units, precision=precision), accumulated_flops_cost, units=units, precision=precision),
'{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops), f'{accumulated_flops_cost / total_flops:.3%} FLOPs',
self.original_extra_repr() self.original_extra_repr()
]) ])
......
...@@ -129,7 +129,7 @@ def _get_bases_name(m): ...@@ -129,7 +129,7 @@ def _get_bases_name(m):
return [b.__name__ for b in m.__class__.__bases__] return [b.__name__ for b in m.__class__.__bases__]
class BaseInit(object): class BaseInit:
def __init__(self, *, bias=0, bias_prob=None, layer=None): def __init__(self, *, bias=0, bias_prob=None, layer=None):
self.wholemodule = False self.wholemodule = False
...@@ -461,7 +461,7 @@ class Caffe2XavierInit(KaimingInit): ...@@ -461,7 +461,7 @@ class Caffe2XavierInit(KaimingInit):
@INITIALIZERS.register_module(name='Pretrained') @INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit(object): class PretrainedInit:
"""Initialize module by loading a pretrained model. """Initialize module by loading a pretrained model.
Args: Args:
......
...@@ -70,7 +70,7 @@ class VGG(nn.Module): ...@@ -70,7 +70,7 @@ class VGG(nn.Module):
bn_frozen=False, bn_frozen=False,
ceil_mode=False, ceil_mode=False,
with_last_pool=True): with_last_pool=True):
super(VGG, self).__init__() super().__init__()
if depth not in self.arch_settings: if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for vgg') raise KeyError(f'invalid depth {depth} for vgg')
assert num_stages >= 1 and num_stages <= 5 assert num_stages >= 1 and num_stages <= 5
...@@ -157,7 +157,7 @@ class VGG(nn.Module): ...@@ -157,7 +157,7 @@ class VGG(nn.Module):
return tuple(outs) return tuple(outs)
def train(self, mode=True): def train(self, mode=True):
super(VGG, self).train(mode) super().train(mode)
if self.bn_eval: if self.bn_eval:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.BatchNorm2d): if isinstance(m, nn.BatchNorm2d):
......
...@@ -33,7 +33,7 @@ class MLUDataParallel(MMDataParallel): ...@@ -33,7 +33,7 @@ class MLUDataParallel(MMDataParallel):
""" """
def __init__(self, *args, dim=0, **kwargs): def __init__(self, *args, dim=0, **kwargs):
super(MLUDataParallel, self).__init__(*args, dim=dim, **kwargs) super().__init__(*args, dim=dim, **kwargs)
self.device_ids = [0] self.device_ids = [0]
self.src_device_obj = torch.device('mlu:0') self.src_device_obj = torch.device('mlu:0')
......
...@@ -210,9 +210,9 @@ class PetrelBackend(BaseStorageBackend): ...@@ -210,9 +210,9 @@ class PetrelBackend(BaseStorageBackend):
""" """
if not has_method(self._client, 'delete'): if not has_method(self._client, 'delete'):
raise NotImplementedError( raise NotImplementedError(
('Current version of Petrel Python SDK has not supported ' 'Current version of Petrel Python SDK has not supported '
'the `delete` method, please use a higher version or dev' 'the `delete` method, please use a higher version or dev'
' branch instead.')) ' branch instead.')
filepath = self._map_path(filepath) filepath = self._map_path(filepath)
filepath = self._format_path(filepath) filepath = self._format_path(filepath)
...@@ -230,9 +230,9 @@ class PetrelBackend(BaseStorageBackend): ...@@ -230,9 +230,9 @@ class PetrelBackend(BaseStorageBackend):
if not (has_method(self._client, 'contains') if not (has_method(self._client, 'contains')
and has_method(self._client, 'isdir')): and has_method(self._client, 'isdir')):
raise NotImplementedError( raise NotImplementedError(
('Current version of Petrel Python SDK has not supported ' 'Current version of Petrel Python SDK has not supported '
'the `contains` and `isdir` methods, please use a higher' 'the `contains` and `isdir` methods, please use a higher'
'version or dev branch instead.')) 'version or dev branch instead.')
filepath = self._map_path(filepath) filepath = self._map_path(filepath)
filepath = self._format_path(filepath) filepath = self._format_path(filepath)
...@@ -251,9 +251,9 @@ class PetrelBackend(BaseStorageBackend): ...@@ -251,9 +251,9 @@ class PetrelBackend(BaseStorageBackend):
""" """
if not has_method(self._client, 'isdir'): if not has_method(self._client, 'isdir'):
raise NotImplementedError( raise NotImplementedError(
('Current version of Petrel Python SDK has not supported ' 'Current version of Petrel Python SDK has not supported '
'the `isdir` method, please use a higher version or dev' 'the `isdir` method, please use a higher version or dev'
' branch instead.')) ' branch instead.')
filepath = self._map_path(filepath) filepath = self._map_path(filepath)
filepath = self._format_path(filepath) filepath = self._format_path(filepath)
...@@ -271,9 +271,9 @@ class PetrelBackend(BaseStorageBackend): ...@@ -271,9 +271,9 @@ class PetrelBackend(BaseStorageBackend):
""" """
if not has_method(self._client, 'contains'): if not has_method(self._client, 'contains'):
raise NotImplementedError( raise NotImplementedError(
('Current version of Petrel Python SDK has not supported ' 'Current version of Petrel Python SDK has not supported '
'the `contains` method, please use a higher version or ' 'the `contains` method, please use a higher version or '
'dev branch instead.')) 'dev branch instead.')
filepath = self._map_path(filepath) filepath = self._map_path(filepath)
filepath = self._format_path(filepath) filepath = self._format_path(filepath)
...@@ -366,9 +366,9 @@ class PetrelBackend(BaseStorageBackend): ...@@ -366,9 +366,9 @@ class PetrelBackend(BaseStorageBackend):
""" """
if not has_method(self._client, 'list'): if not has_method(self._client, 'list'):
raise NotImplementedError( raise NotImplementedError(
('Current version of Petrel Python SDK has not supported ' 'Current version of Petrel Python SDK has not supported '
'the `list` method, please use a higher version or dev' 'the `list` method, please use a higher version or dev'
' branch instead.')) ' branch instead.')
dir_path = self._map_path(dir_path) dir_path = self._map_path(dir_path)
dir_path = self._format_path(dir_path) dir_path = self._format_path(dir_path)
...@@ -549,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend): ...@@ -549,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend):
Returns: Returns:
str: Expected text reading from ``filepath``. str: Expected text reading from ``filepath``.
""" """
with open(filepath, 'r', encoding=encoding) as f: with open(filepath, encoding=encoding) as f:
value_buf = f.read() value_buf = f.read()
return value_buf return value_buf
......
...@@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler): ...@@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler):
return pickle.load(file, **kwargs) return pickle.load(file, **kwargs)
def load_from_path(self, filepath, **kwargs): def load_from_path(self, filepath, **kwargs):
return super(PickleHandler, self).load_from_path( return super().load_from_path(filepath, mode='rb', **kwargs)
filepath, mode='rb', **kwargs)
def dump_to_str(self, obj, **kwargs): def dump_to_str(self, obj, **kwargs):
kwargs.setdefault('protocol', 2) kwargs.setdefault('protocol', 2)
...@@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler): ...@@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler):
pickle.dump(obj, file, **kwargs) pickle.dump(obj, file, **kwargs)
def dump_to_path(self, obj, filepath, **kwargs): def dump_to_path(self, obj, filepath, **kwargs):
super(PickleHandler, self).dump_to_path( super().dump_to_path(obj, filepath, mode='wb', **kwargs)
obj, filepath, mode='wb', **kwargs)
...@@ -157,7 +157,7 @@ def imresize_to_multiple(img, ...@@ -157,7 +157,7 @@ def imresize_to_multiple(img,
size = _scale_size((w, h), scale_factor) size = _scale_size((w, h), scale_factor)
divisor = to_2tuple(divisor) divisor = to_2tuple(divisor)
size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)]) size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
resized_img, w_scale, h_scale = imresize( resized_img, w_scale, h_scale = imresize(
img, img,
size, size,
......
...@@ -59,7 +59,7 @@ def _parse_arg(value, desc): ...@@ -59,7 +59,7 @@ def _parse_arg(value, desc):
raise RuntimeError( raise RuntimeError(
"ONNX symbolic doesn't know to interpret ListConstruct node") "ONNX symbolic doesn't know to interpret ListConstruct node")
raise RuntimeError('Unexpected node type: {}'.format(value.node().kind())) raise RuntimeError(f'Unexpected node type: {value.node().kind()}')
def _maybe_get_const(value, desc): def _maybe_get_const(value, desc):
......
...@@ -86,7 +86,7 @@ class BorderAlign(nn.Module): ...@@ -86,7 +86,7 @@ class BorderAlign(nn.Module):
""" """
def __init__(self, pool_size): def __init__(self, pool_size):
super(BorderAlign, self).__init__() super().__init__()
self.pool_size = pool_size self.pool_size = pool_size
def forward(self, input, boxes): def forward(self, input, boxes):
......
...@@ -131,7 +131,7 @@ def box_iou_rotated(bboxes1, ...@@ -131,7 +131,7 @@ def box_iou_rotated(bboxes1,
if aligned: if aligned:
ious = bboxes1.new_zeros(rows) ious = bboxes1.new_zeros(rows)
else: else:
ious = bboxes1.new_zeros((rows * cols)) ious = bboxes1.new_zeros(rows * cols)
if not clockwise: if not clockwise:
flip_mat = bboxes1.new_ones(bboxes1.shape[-1]) flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
flip_mat[-1] = -1 flip_mat[-1] = -1
......
...@@ -85,7 +85,7 @@ carafe_naive = CARAFENaiveFunction.apply ...@@ -85,7 +85,7 @@ carafe_naive = CARAFENaiveFunction.apply
class CARAFENaive(Module): class CARAFENaive(Module):
def __init__(self, kernel_size, group_size, scale_factor): def __init__(self, kernel_size, group_size, scale_factor):
super(CARAFENaive, self).__init__() super().__init__()
assert isinstance(kernel_size, int) and isinstance( assert isinstance(kernel_size, int) and isinstance(
group_size, int) and isinstance(scale_factor, int) group_size, int) and isinstance(scale_factor, int)
...@@ -195,7 +195,7 @@ class CARAFE(Module): ...@@ -195,7 +195,7 @@ class CARAFE(Module):
""" """
def __init__(self, kernel_size, group_size, scale_factor): def __init__(self, kernel_size, group_size, scale_factor):
super(CARAFE, self).__init__() super().__init__()
assert isinstance(kernel_size, int) and isinstance( assert isinstance(kernel_size, int) and isinstance(
group_size, int) and isinstance(scale_factor, int) group_size, int) and isinstance(scale_factor, int)
...@@ -238,7 +238,7 @@ class CARAFEPack(nn.Module): ...@@ -238,7 +238,7 @@ class CARAFEPack(nn.Module):
encoder_kernel=3, encoder_kernel=3,
encoder_dilation=1, encoder_dilation=1,
compressed_channels=64): compressed_channels=64):
super(CARAFEPack, self).__init__() super().__init__()
self.channels = channels self.channels = channels
self.scale_factor = scale_factor self.scale_factor = scale_factor
self.up_kernel = up_kernel self.up_kernel = up_kernel
......
...@@ -125,7 +125,7 @@ class CornerPool(nn.Module): ...@@ -125,7 +125,7 @@ class CornerPool(nn.Module):
} }
def __init__(self, mode): def __init__(self, mode):
super(CornerPool, self).__init__() super().__init__()
assert mode in self.pool_functions assert mode in self.pool_functions
self.mode = mode self.mode = mode
self.corner_pool = self.pool_functions[mode] self.corner_pool = self.pool_functions[mode]
......
...@@ -236,7 +236,7 @@ class DeformConv2d(nn.Module): ...@@ -236,7 +236,7 @@ class DeformConv2d(nn.Module):
deform_groups: int = 1, deform_groups: int = 1,
bias: bool = False, bias: bool = False,
im2col_step: int = 32) -> None: im2col_step: int = 32) -> None:
super(DeformConv2d, self).__init__() super().__init__()
assert not bias, \ assert not bias, \
f'bias={bias} is not supported in DeformConv2d.' f'bias={bias} is not supported in DeformConv2d.'
...@@ -356,7 +356,7 @@ class DeformConv2dPack(DeformConv2d): ...@@ -356,7 +356,7 @@ class DeformConv2dPack(DeformConv2d):
_version = 2 _version = 2
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(DeformConv2dPack, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d( self.conv_offset = nn.Conv2d(
self.in_channels, self.in_channels,
self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1], self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
......
...@@ -96,7 +96,7 @@ class DeformRoIPool(nn.Module): ...@@ -96,7 +96,7 @@ class DeformRoIPool(nn.Module):
spatial_scale=1.0, spatial_scale=1.0,
sampling_ratio=0, sampling_ratio=0,
gamma=0.1): gamma=0.1):
super(DeformRoIPool, self).__init__() super().__init__()
self.output_size = _pair(output_size) self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale) self.spatial_scale = float(spatial_scale)
self.sampling_ratio = int(sampling_ratio) self.sampling_ratio = int(sampling_ratio)
...@@ -117,8 +117,7 @@ class DeformRoIPoolPack(DeformRoIPool): ...@@ -117,8 +117,7 @@ class DeformRoIPoolPack(DeformRoIPool):
spatial_scale=1.0, spatial_scale=1.0,
sampling_ratio=0, sampling_ratio=0,
gamma=0.1): gamma=0.1):
super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale, super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
sampling_ratio, gamma)
self.output_channels = output_channels self.output_channels = output_channels
self.deform_fc_channels = deform_fc_channels self.deform_fc_channels = deform_fc_channels
...@@ -158,8 +157,7 @@ class ModulatedDeformRoIPoolPack(DeformRoIPool): ...@@ -158,8 +157,7 @@ class ModulatedDeformRoIPoolPack(DeformRoIPool):
spatial_scale=1.0, spatial_scale=1.0,
sampling_ratio=0, sampling_ratio=0,
gamma=0.1): gamma=0.1):
super(ModulatedDeformRoIPoolPack, super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
self).__init__(output_size, spatial_scale, sampling_ratio, gamma)
self.output_channels = output_channels self.output_channels = output_channels
self.deform_fc_channels = deform_fc_channels self.deform_fc_channels = deform_fc_channels
......
...@@ -89,7 +89,7 @@ sigmoid_focal_loss = SigmoidFocalLossFunction.apply ...@@ -89,7 +89,7 @@ sigmoid_focal_loss = SigmoidFocalLossFunction.apply
class SigmoidFocalLoss(nn.Module): class SigmoidFocalLoss(nn.Module):
def __init__(self, gamma, alpha, weight=None, reduction='mean'): def __init__(self, gamma, alpha, weight=None, reduction='mean'):
super(SigmoidFocalLoss, self).__init__() super().__init__()
self.gamma = gamma self.gamma = gamma
self.alpha = alpha self.alpha = alpha
self.register_buffer('weight', weight) self.register_buffer('weight', weight)
...@@ -195,7 +195,7 @@ softmax_focal_loss = SoftmaxFocalLossFunction.apply ...@@ -195,7 +195,7 @@ softmax_focal_loss = SoftmaxFocalLossFunction.apply
class SoftmaxFocalLoss(nn.Module): class SoftmaxFocalLoss(nn.Module):
def __init__(self, gamma, alpha, weight=None, reduction='mean'): def __init__(self, gamma, alpha, weight=None, reduction='mean'):
super(SoftmaxFocalLoss, self).__init__() super().__init__()
self.gamma = gamma self.gamma = gamma
self.alpha = alpha self.alpha = alpha
self.register_buffer('weight', weight) self.register_buffer('weight', weight)
......
...@@ -212,7 +212,7 @@ class FusedBiasLeakyReLU(nn.Module): ...@@ -212,7 +212,7 @@ class FusedBiasLeakyReLU(nn.Module):
""" """
def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5): def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
super(FusedBiasLeakyReLU, self).__init__() super().__init__()
self.bias = nn.Parameter(torch.zeros(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels))
self.negative_slope = negative_slope self.negative_slope = negative_slope
......
...@@ -98,13 +98,12 @@ class MaskedConv2d(nn.Conv2d): ...@@ -98,13 +98,12 @@ class MaskedConv2d(nn.Conv2d):
dilation=1, dilation=1,
groups=1, groups=1,
bias=True): bias=True):
super(MaskedConv2d, super().__init__(in_channels, out_channels, kernel_size, stride,
self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
padding, dilation, groups, bias)
def forward(self, input, mask=None): def forward(self, input, mask=None):
if mask is None: # fallback to the normal Conv2d if mask is None: # fallback to the normal Conv2d
return super(MaskedConv2d, self).forward(input) return super().forward(input)
else: else:
return masked_conv2d(input, mask, self.weight, self.bias, return masked_conv2d(input, mask, self.weight, self.bias,
self.padding) self.padding)
...@@ -53,7 +53,7 @@ class BaseMergeCell(nn.Module): ...@@ -53,7 +53,7 @@ class BaseMergeCell(nn.Module):
input_conv_cfg=None, input_conv_cfg=None,
input_norm_cfg=None, input_norm_cfg=None,
upsample_mode='nearest'): upsample_mode='nearest'):
super(BaseMergeCell, self).__init__() super().__init__()
assert upsample_mode in ['nearest', 'bilinear'] assert upsample_mode in ['nearest', 'bilinear']
self.with_out_conv = with_out_conv self.with_out_conv = with_out_conv
self.with_input1_conv = with_input1_conv self.with_input1_conv = with_input1_conv
...@@ -121,7 +121,7 @@ class BaseMergeCell(nn.Module): ...@@ -121,7 +121,7 @@ class BaseMergeCell(nn.Module):
class SumCell(BaseMergeCell): class SumCell(BaseMergeCell):
def __init__(self, in_channels, out_channels, **kwargs): def __init__(self, in_channels, out_channels, **kwargs):
super(SumCell, self).__init__(in_channels, out_channels, **kwargs) super().__init__(in_channels, out_channels, **kwargs)
def _binary_op(self, x1, x2): def _binary_op(self, x1, x2):
return x1 + x2 return x1 + x2
...@@ -130,8 +130,7 @@ class SumCell(BaseMergeCell): ...@@ -130,8 +130,7 @@ class SumCell(BaseMergeCell):
class ConcatCell(BaseMergeCell): class ConcatCell(BaseMergeCell):
def __init__(self, in_channels, out_channels, **kwargs): def __init__(self, in_channels, out_channels, **kwargs):
super(ConcatCell, self).__init__(in_channels * 2, out_channels, super().__init__(in_channels * 2, out_channels, **kwargs)
**kwargs)
def _binary_op(self, x1, x2): def _binary_op(self, x1, x2):
ret = torch.cat([x1, x2], dim=1) ret = torch.cat([x1, x2], dim=1)
......
...@@ -168,7 +168,7 @@ class ModulatedDeformConv2d(nn.Module): ...@@ -168,7 +168,7 @@ class ModulatedDeformConv2d(nn.Module):
groups=1, groups=1,
deform_groups=1, deform_groups=1,
bias=True): bias=True):
super(ModulatedDeformConv2d, self).__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
self.kernel_size = _pair(kernel_size) self.kernel_size = _pair(kernel_size)
...@@ -227,7 +227,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d): ...@@ -227,7 +227,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
_version = 2 _version = 2
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d( self.conv_offset = nn.Conv2d(
self.in_channels, self.in_channels,
self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1], self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
...@@ -239,7 +239,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d): ...@@ -239,7 +239,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
self.init_weights() self.init_weights()
def init_weights(self): def init_weights(self):
super(ModulatedDeformConv2dPack, self).init_weights() super().init_weights()
if hasattr(self, 'conv_offset'): if hasattr(self, 'conv_offset'):
self.conv_offset.weight.data.zero_() self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_() self.conv_offset.bias.data.zero_()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment