Commit 41b18fd8 authored by zhe chen's avatar zhe chen
Browse files

Use pre-commit to reformat code


Use pre-commit to reformat code
parent ff20ea39
...@@ -9,17 +9,16 @@ import torch.nn.functional as F ...@@ -9,17 +9,16 @@ import torch.nn.functional as F
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer, from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
build_norm_layer, xavier_init) build_norm_layer, xavier_init)
from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE,
FEEDFORWARD_NETWORK)
from mmcv.cnn.bricks.drop import build_dropout from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.registry import (FEEDFORWARD_NETWORK, TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer, from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
TransformerLayerSequence, TransformerLayerSequence,
build_transformer_layer_sequence,
build_attention, build_attention,
build_feedforward_network) build_feedforward_network,
build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import to_2tuple, ConfigDict, deprecated_api_warning from mmcv.utils import ConfigDict, deprecated_api_warning, to_2tuple
from torch.nn.init import normal_ from torch.nn.init import normal_
from ..builder import TRANSFORMER from ..builder import TRANSFORMER
...@@ -319,12 +318,12 @@ class FFN(BaseModule): ...@@ -319,12 +318,12 @@ class FFN(BaseModule):
"""Forward function for `FFN`. """Forward function for `FFN`.
The function would add x to the output tensor if residue is None. The function would add x to the output tensor if residue is None.
""" """
if self.with_cp and x.requires_grad: if self.with_cp and x.requires_grad:
out = cp.checkpoint(self.layers, x) out = cp.checkpoint(self.layers, x)
else: else:
out = self.layers(x) out = self.layers(x)
if not self.add_identity: if not self.add_identity:
return self.dropout_layer(out) return self.dropout_layer(out)
if identity is None: if identity is None:
......
...@@ -4,16 +4,14 @@ ...@@ -4,16 +4,14 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import print_function
from __future__ import division
import DCNv3
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Function from torch.autograd import Function
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd from torch.cuda.amp import custom_bwd, custom_fwd
import DCNv3
class DCNv3Function(Function): class DCNv3Function(Function):
...@@ -88,6 +86,7 @@ class DCNv3Function(Function): ...@@ -88,6 +86,7 @@ class DCNv3Function(Function):
im2col_step_i=int(im2col_step), im2col_step_i=int(im2col_step),
) )
def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1): def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
_, H_, W_, _ = spatial_shapes _, H_, W_, _ = spatial_shapes
H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
......
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from .dcnv3 import DCNv3, DCNv3_pytorch from .dcnv3 import DCNv3, DCNv3_pytorch
\ No newline at end of file
...@@ -4,22 +4,24 @@ ...@@ -4,22 +4,24 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import print_function
from __future__ import division
import warnings import warnings
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_ from torch import nn
from torch.nn.init import constant_, xavier_uniform_
from ..functions import DCNv3Function, dcnv3_core_pytorch from ..functions import DCNv3Function, dcnv3_core_pytorch
try: try:
from DCNv4.functions import DCNv4Function from DCNv4.functions import DCNv4Function
except: except:
warnings.warn('Now, we support DCNv4 in InternImage.') warnings.warn('Now, we support DCNv4 in InternImage.')
import math import math
class to_channels_first(nn.Module): class to_channels_first(nn.Module):
def __init__(self): def __init__(self):
...@@ -76,7 +78,7 @@ def build_act_layer(act_layer): ...@@ -76,7 +78,7 @@ def build_act_layer(act_layer):
def _is_power_of_2(n): def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0): if (not isinstance(n, int)) or (n < 0):
raise ValueError( raise ValueError(
"invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n)))
return (n & (n - 1) == 0) and n != 0 return (n & (n - 1) == 0) and n != 0
...@@ -128,7 +130,7 @@ class DCNv3_pytorch(nn.Module): ...@@ -128,7 +130,7 @@ class DCNv3_pytorch(nn.Module):
if not _is_power_of_2(_d_per_group): if not _is_power_of_2(_d_per_group):
warnings.warn( warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.") 'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale self.offset_scale = offset_scale
self.channels = channels self.channels = channels
...@@ -165,7 +167,7 @@ class DCNv3_pytorch(nn.Module): ...@@ -165,7 +167,7 @@ class DCNv3_pytorch(nn.Module):
self.input_proj = nn.Linear(channels, channels) self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels) self.output_proj = nn.Linear(channels, channels)
self._reset_parameters() self._reset_parameters()
if center_feature_scale: if center_feature_scale:
self.center_feature_scale_proj_weight = nn.Parameter( self.center_feature_scale_proj_weight = nn.Parameter(
torch.zeros((group, channels), dtype=torch.float)) torch.zeros((group, channels), dtype=torch.float))
...@@ -234,7 +236,7 @@ class DCNv3(nn.Module): ...@@ -234,7 +236,7 @@ class DCNv3(nn.Module):
norm_layer='LN', norm_layer='LN',
center_feature_scale=False, center_feature_scale=False,
use_dcn_v4_op=False, use_dcn_v4_op=False,
): ):
""" """
DCNv3 Module DCNv3 Module
:param channels :param channels
...@@ -257,7 +259,7 @@ class DCNv3(nn.Module): ...@@ -257,7 +259,7 @@ class DCNv3(nn.Module):
if not _is_power_of_2(_d_per_group): if not _is_power_of_2(_d_per_group):
warnings.warn( warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.") 'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale self.offset_scale = offset_scale
self.channels = channels self.channels = channels
...@@ -270,7 +272,7 @@ class DCNv3(nn.Module): ...@@ -270,7 +272,7 @@ class DCNv3(nn.Module):
self.group_channels = channels // group self.group_channels = channels // group
self.offset_scale = offset_scale self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale self.center_feature_scale = center_feature_scale
self.use_dcn_v4_op = use_dcn_v4_op self.use_dcn_v4_op = use_dcn_v4_op
self.dw_conv = nn.Sequential( self.dw_conv = nn.Sequential(
...@@ -296,7 +298,7 @@ class DCNv3(nn.Module): ...@@ -296,7 +298,7 @@ class DCNv3(nn.Module):
self.input_proj = nn.Linear(channels, channels) self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels) self.output_proj = nn.Linear(channels, channels)
self._reset_parameters() self._reset_parameters()
if center_feature_scale: if center_feature_scale:
self.center_feature_scale_proj_weight = nn.Parameter( self.center_feature_scale_proj_weight = nn.Parameter(
torch.zeros((group, channels), dtype=torch.float)) torch.zeros((group, channels), dtype=torch.float))
...@@ -329,7 +331,7 @@ class DCNv3(nn.Module): ...@@ -329,7 +331,7 @@ class DCNv3(nn.Module):
x1 = self.dw_conv(x1) x1 = self.dw_conv(x1)
offset = self.offset(x1) offset = self.offset(x1)
mask = self.mask(x1).reshape(N, H, W, self.group, -1) mask = self.mask(x1).reshape(N, H, W, self.group, -1)
if not self.use_dcn_v4_op: if not self.use_dcn_v4_op:
mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype) mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype)
x = DCNv3Function.apply( x = DCNv3Function.apply(
...@@ -349,12 +351,12 @@ class DCNv3(nn.Module): ...@@ -349,12 +351,12 @@ class DCNv3(nn.Module):
mask = mask.view(N, H, W, self.group, -1) mask = mask.view(N, H, W, self.group, -1)
offset_mask = torch.cat([offset, mask], -1).view(N, H, W, -1).contiguous() offset_mask = torch.cat([offset, mask], -1).view(N, H, W, -1).contiguous()
# For efficiency, the last dimension of the offset_mask tensor in dcnv4 is a multiple of 8. # For efficiency, the last dimension of the offset_mask tensor in dcnv4 is a multiple of 8.
K3 = offset_mask.size(-1) K3 = offset_mask.size(-1)
K3_pad = int(math.ceil(K3/8)*8) K3_pad = int(math.ceil(K3 / 8) * 8)
pad_dim = K3_pad - K3 pad_dim = K3_pad - K3
offset_mask = torch.cat([offset_mask, offset_mask.new_zeros([*offset_mask.size()[:3], pad_dim])], -1) offset_mask = torch.cat([offset_mask, offset_mask.new_zeros([*offset_mask.size()[:3], pad_dim])], -1)
x = DCNv4Function.apply( x = DCNv4Function.apply(
x, offset_mask, x, offset_mask,
self.kernel_size, self.kernel_size, self.kernel_size, self.kernel_size,
......
...@@ -4,39 +4,34 @@ ...@@ -4,39 +4,34 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
import os
import glob import glob
import os
import torch import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME requirements = ['torch', 'torchvision']
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension
from setuptools import find_packages
from setuptools import setup
requirements = ["torch", "torchvision"]
def get_extensions(): def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src") extensions_dir = os.path.join(this_dir, 'src')
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
sources = main_file + source_cpu sources = main_file + source_cpu
extension = CppExtension extension = CppExtension
extra_compile_args = {"cxx": []} extra_compile_args = {'cxx': []}
define_macros = [] define_macros = []
if torch.cuda.is_available() and CUDA_HOME is not None: if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension extension = CUDAExtension
sources += source_cuda sources += source_cuda
define_macros += [("WITH_CUDA", None)] define_macros += [('WITH_CUDA', None)]
extra_compile_args["nvcc"] = [ extra_compile_args['nvcc'] = [
# "-DCUDA_HAS_FP16=1", # "-DCUDA_HAS_FP16=1",
# "-D__CUDA_NO_HALF_OPERATORS__", # "-D__CUDA_NO_HALF_OPERATORS__",
# "-D__CUDA_NO_HALF_CONVERSIONS__", # "-D__CUDA_NO_HALF_CONVERSIONS__",
...@@ -49,7 +44,7 @@ def get_extensions(): ...@@ -49,7 +44,7 @@ def get_extensions():
include_dirs = [extensions_dir] include_dirs = [extensions_dir]
ext_modules = [ ext_modules = [
extension( extension(
"DCNv3", 'DCNv3',
sources, sources,
include_dirs=include_dirs, include_dirs=include_dirs,
define_macros=define_macros, define_macros=define_macros,
...@@ -60,16 +55,16 @@ def get_extensions(): ...@@ -60,16 +55,16 @@ def get_extensions():
setup( setup(
name="DCNv3", name='DCNv3',
version="1.0", version='1.0',
author="InternImage", author='InternImage',
url="https://github.com/OpenGVLab/InternImage", url='https://github.com/OpenGVLab/InternImage',
description= description=
"PyTorch Wrapper for CUDA Functions of DCNv3", 'PyTorch Wrapper for CUDA Functions of DCNv3',
packages=find_packages(exclude=( packages=find_packages(exclude=(
"configs", 'configs',
"tests", 'tests',
)), )),
ext_modules=get_extensions(), ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
) )
...@@ -171,4 +171,4 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, ...@@ -171,4 +171,4 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
} else { } else {
return {grad_input, grad_offset, grad_mask}; return {grad_input, grad_offset, grad_mask};
} }
} }
\ No newline at end of file
...@@ -1042,4 +1042,4 @@ void dcnv3_col2im_cuda( ...@@ -1042,4 +1042,4 @@ void dcnv3_col2im_cuda(
if (err != cudaSuccess) { if (err != cudaSuccess) {
printf("error in dcnv3_col2im_cuda: %s\n", cudaGetErrorString(err)); printf("error in dcnv3_col2im_cuda: %s\n", cudaGetErrorString(err));
} }
} }
\ No newline at end of file
...@@ -4,17 +4,15 @@ ...@@ -4,17 +4,15 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import print_function
from __future__ import division
import math
import time import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import math
from torch.autograd import gradcheck
from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
from torch.autograd import gradcheck
H_in, W_in = 8, 8 H_in, W_in = 8, 8
N, M, D = 2, 4, 16 N, M, D = 2, 4, 16
......
...@@ -12,18 +12,17 @@ import time ...@@ -12,18 +12,17 @@ import time
import warnings import warnings
import mmcv import mmcv
import mmcv_custom # noqa: F401,F403
import mmseg_custom # noqa: F401,F403
import torch import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model, load_state_dict) load_state_dict, wrap_fp16_model)
from mmcv.utils import DictAction from mmcv.utils import DictAction
from mmseg.apis import multi_gpu_test, single_gpu_test from mmseg.apis import multi_gpu_test, single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor from mmseg.models import build_segmentor
import mmcv_custom # noqa: F401,F403
import mmseg_custom # noqa: F401,F403
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -197,7 +196,7 @@ def main(): ...@@ -197,7 +196,7 @@ def main():
load_state_dict(model.module, checkpoint['state_dict'], strict=False) load_state_dict(model.module, checkpoint['state_dict'], strict=False)
else: else:
load_state_dict(model, checkpoint['state_dict'], strict=False) load_state_dict(model, checkpoint['state_dict'], strict=False)
if 'CLASSES' in checkpoint.get('meta', {}): if 'CLASSES' in checkpoint.get('meta', {}):
model.CLASSES = checkpoint['meta']['CLASSES'] model.CLASSES = checkpoint['meta']['CLASSES']
else: else:
......
...@@ -12,20 +12,19 @@ import time ...@@ -12,20 +12,19 @@ import time
import warnings import warnings
import mmcv import mmcv
import mmcv_custom # noqa: F401,F403
import mmseg_custom # noqa: F401,F403
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from mmcv.cnn.utils import revert_sync_batchnorm from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.runner import get_dist_info, init_dist from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import Config, DictAction, get_git_hash from mmcv.utils import Config, DictAction, get_git_hash
from mmseg import __version__ from mmseg import __version__
from mmseg.apis import init_random_seed, set_random_seed, train_segmentor from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
from mmseg.datasets import build_dataset from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor from mmseg.models import build_segmentor
from mmseg.utils import (collect_env, get_device, get_root_logger, from mmseg.utils import (collect_env, get_device, get_root_logger,
setup_multi_processes) setup_multi_processes)
import mmcv_custom # noqa: F401,F403
import mmseg_custom # noqa: F401,F403
def parse_args(): def parse_args():
...@@ -231,10 +230,10 @@ def main(): ...@@ -231,10 +230,10 @@ def main():
model.CLASSES = datasets[0].CLASSES model.CLASSES = datasets[0].CLASSES
# passing checkpoint meta for saving best checkpoint # passing checkpoint meta for saving best checkpoint
meta.update(cfg.checkpoint_config.meta) meta.update(cfg.checkpoint_config.meta)
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
train_segmentor(model, train_segmentor(model,
datasets, datasets,
cfg, cfg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment