"vscode:/vscode.git/clone" did not exist on "cb4c3c8c66405caca780ccf3d4af3d8cc581f9fd"
Unverified Commit c310d28c authored by mengpenghui, committed by GitHub

[Feature] Add MLU support for DCN (#2540)

parent c9d477bb
docs/en/understand_mmcv/ops.md
@@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc.
 | ConvexIoU                    |     | √    |     |     |        |
 | CornerPool                   |     | √    |     |     |        |
 | Correlation                  |     | √    |     |     |        |
-| Deformable Convolution v1/v2 | √   | √    |     |     | √      |
+| Deformable Convolution v1/v2 | √   | √    | √   |     | √      |
 | Deformable RoIPool           |     | √    | √   |     | √      |
 | DiffIoURotated               |     | √    |     |     |        |
 | DynamicScatter               |     | √    |     |     |        |
...
docs/zh_cn/understand_mmcv/ops.md
@@ -18,7 +18,7 @@ MMCV provides commonly used operators for detection, segmentation, and other tasks
 | ConvexIoU                    |     | √    |     |     |        |
 | CornerPool                   |     | √    |     |     |        |
 | Correlation                  |     | √    |     |     |        |
-| Deformable Convolution v1/v2 | √   | √    |     |     | √      |
+| Deformable Convolution v1/v2 | √   | √    | √   |     | √      |
 | Deformable RoIPool           |     | √    | √   |     | √      |
 | DiffIoURotated               |     | √    |     |     |        |
 | DynamicScatter               |     | √    |     |     |        |
...
mmcv/ops/__init__.py
@@ -109,6 +109,7 @@ __all__ = [
 ]

 if IS_MLU_AVAILABLE:
+    from .deform_conv import DeformConv2dPack_MLU  # noqa:F401
     from .modulated_deform_conv import \
         ModulatedDeformConv2dPack_MLU  # noqa:F401
-    __all__.append('ModulatedDeformConv2dPack_MLU')
+    __all__.extend(['ModulatedDeformConv2dPack_MLU', 'DeformConv2dPack_MLU'])
...
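With the new export in place, downstream code can pick the MLU pack conditionally. A minimal sketch of the import pattern (it mirrors the one used in the tests further down; assumes an MLU-enabled build where `IS_MLU_AVAILABLE` is true):

```python
# Fall back to the standard pack when no MLU device is present.
from mmcv.utils import IS_MLU_AVAILABLE

if IS_MLU_AVAILABLE:
    from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack
else:
    from mmcv.ops import DeformConv2dPack
```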
mmcv/ops/deform_conv.py
@@ -9,7 +9,7 @@ from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single

-from mmcv.utils import deprecated_api_warning
+from mmcv.utils import IS_MLU_AVAILABLE, deprecated_api_warning
 from ..cnn import CONV_LAYERS
 from ..utils import ext_loader, print_log
 from .modulated_deform_conv import ModulatedDeformConv2dFunction
...
@@ -434,3 +434,67 @@ class DeformConv2dPack(DeformConv2d):
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, missing_keys, unexpected_keys,
                                       error_msgs)
+
+
+if IS_MLU_AVAILABLE:
+    import torchvision
+    from mmcv.utils import digit_version
+    assert digit_version(torchvision.__version__) >= digit_version(
+        '0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
+    from torchvision.ops import deform_conv2d as tv_deform_conv2d
+
+    @CONV_LAYERS.register_module('DCN', force=True)
+    class DeformConv2dPack_MLU(DeformConv2d):
+        """The DCN implementation for MLU devices.
+
+        torchvision already provides an MLU backend for this operator, so
+        this class reuses the mmcv registration mechanism and dispatches to
+        the torchvision implementation of DCN.
+
+        Args:
+            in_channels (int): Same as nn.Conv2d.
+            out_channels (int): Same as nn.Conv2d.
+            kernel_size (int or tuple[int]): Same as nn.Conv2d.
+            stride (int): Same as nn.Conv2d, while tuple is not supported.
+            padding (int): Same as nn.Conv2d, while tuple is not supported.
+            dilation (int): Same as nn.Conv2d, while tuple is not supported.
+            groups (int): Same as nn.Conv2d.
+            bias (bool or str): If specified as `auto`, it will be decided by
+                the norm_cfg. Bias will be set as True if norm_cfg is None,
+                otherwise False.
+            im2col_step (int): Number of samples processed by
+                im2col_cuda_kernel per call. It will work when ``batch_size``
+                > ``im2col_step``, but ``batch_size`` must be divisible by
+                ``im2col_step``. Default: 32. `New in version 1.7.2.
+                Currently not supported on MLU devices.`
+        """
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            # Predict per-location offsets from the input, as in
+            # DeformConv2dPack.
+            self.conv_offset = nn.Conv2d(
+                self.in_channels,
+                self.deform_groups * 2 * self.kernel_size[0] *
+                self.kernel_size[1],
+                kernel_size=self.kernel_size,
+                stride=_pair(self.stride),
+                padding=_pair(self.padding),
+                dilation=_pair(self.dilation),
+                bias=True)
+            self.init_offset()
+
+        def init_offset(self):
+            self.conv_offset.weight.data.zero_()
+            self.conv_offset.bias.data.zero_()
+
+        def forward(self, x: Tensor) -> Tensor:  # type: ignore
+            # im2col_step is only validated here; torchvision's
+            # deform_conv2d does not accept it as an argument.
+            cur_im2col_step = min(self.im2col_step, x.size(0))
+            assert (x.size(0) % cur_im2col_step
+                    ) == 0, 'batch size must be divisible by im2col_step'
+            offset = self.conv_offset(x)
+            x = x.type_as(offset)
+            weight = self.weight.type_as(x)
+            return tv_deform_conv2d(x, offset, weight, None, self.stride,
+                                    self.padding, self.dilation)
...
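Because the class is registered under the existing 'DCN' key with `force=True`, config-driven models need no changes on MLU machines. A hedged usage sketch via mmcv's conv-layer registry (channel counts and shapes here are illustrative, not from the patch):

```python
# Build the registered 'DCN' conv layer through mmcv's registry; on an MLU
# machine the key resolves to DeformConv2dPack_MLU thanks to force=True.
import torch
from mmcv.cnn import build_conv_layer

dcn = build_conv_layer(
    dict(type='DCN', deform_groups=1),
    in_channels=3,
    out_channels=8,
    kernel_size=3,
    padding=1)
x = torch.randn(2, 3, 16, 16)  # move x and dcn with .mlu() on a real device
out = dcn(x)  # offsets come from the internal conv_offset layer
```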
tests/test_ops/test_deform_conv.py
@@ -3,7 +3,7 @@ import numpy as np
 import pytest
 import torch

-from mmcv.utils import TORCH_VERSION, digit_version
+from mmcv.utils import IS_MLU_AVAILABLE, TORCH_VERSION, digit_version

 try:
     # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
...
@@ -45,7 +45,10 @@ class TestDeformconv:
                           im2col_step=2):
         if not torch.cuda.is_available() and device == 'cuda':
             pytest.skip('test requires GPU')
-        from mmcv.ops import DeformConv2dPack
+        if device == 'mlu':
+            from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack
+        else:
+            from mmcv.ops import DeformConv2dPack
         c_in = 1
         c_out = 1
         batch_size = 10
...
@@ -69,6 +72,8 @@ class TestDeformconv:
             torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
         if device == 'cuda':
             model.cuda()
+        elif device == 'mlu':
+            model.mlu()
         model.type(dtype)
         out = model(x)
...
@@ -108,6 +113,7 @@ class TestDeformconv:
     def _test_amp_deformconv(self,
                              input_dtype,
                              threshold=1e-3,
+                             device='cuda',
                              batch_size=10,
                              im2col_step=2):
         """The function to test amp released on pytorch 1.6.0.
...
"""The function to test amp released on pytorch 1.6.0. """The function to test amp released on pytorch 1.6.0.
...@@ -120,15 +126,18 @@ class TestDeformconv: ...@@ -120,15 +126,18 @@ class TestDeformconv:
input_dtype: torch.float or torch.half. input_dtype: torch.float or torch.half.
threshold: the same as above function. threshold: the same as above function.
""" """
if not torch.cuda.is_available(): if not torch.cuda.is_available() and device == 'cuda':
return return
from mmcv.ops import DeformConv2dPack if device == 'mlu':
from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack
else:
from mmcv.ops import DeformConv2dPack
c_in = 1 c_in = 1
c_out = 1 c_out = 1
repeated_input = np.repeat(input, batch_size, axis=0) repeated_input = np.repeat(input, batch_size, axis=0)
repeated_gt_out = np.repeat(gt_out, batch_size, axis=0) repeated_gt_out = np.repeat(gt_out, batch_size, axis=0)
repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0) repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0)
x = torch.Tensor(repeated_input).cuda().type(input_dtype) x = torch.Tensor(repeated_input).to(device).type(input_dtype)
x.requires_grad = True x.requires_grad = True
model = DeformConv2dPack( model = DeformConv2dPack(
in_channels=c_in, in_channels=c_in,
@@ -143,7 +152,10 @@ class TestDeformconv:
             torch.Tensor(offset_bias).reshape(8))
         model.weight.data = torch.nn.Parameter(
             torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
-        model.cuda()
+        if device == 'cuda':
+            model.cuda()
+        elif device == 'mlu':
+            model.mlu()
         out = model(x)
         out.backward(torch.ones_like(out))
...
@@ -180,21 +192,25 @@ class TestDeformconv:
     def test_deformconv(self):
         self._test_deformconv(torch.double, device='cpu')
         self._test_deformconv(torch.float, device='cpu', threshold=1e-1)
-        self._test_deformconv(torch.double)
-        self._test_deformconv(torch.float)
-        self._test_deformconv(torch.half, threshold=1e-1)
+        device = 'mlu' if IS_MLU_AVAILABLE else 'cuda'
+        self._test_deformconv(torch.double, device=device)
+        self._test_deformconv(torch.float, device=device)
+        self._test_deformconv(torch.half, threshold=1e-1, device=device)
         # test batch_size < im2col_step
-        self._test_deformconv(torch.float, batch_size=1, im2col_step=2)
+        self._test_deformconv(
+            torch.float, batch_size=1, im2col_step=2, device=device)
         # test batch_size % im2col_step != 0
         with pytest.raises(
                 AssertionError,
                 match='batch size must be divisible by im2col_step'):
-            self._test_deformconv(torch.float, batch_size=10, im2col_step=3)
+            self._test_deformconv(
+                torch.float, batch_size=10, im2col_step=3, device=device)
         # test amp when torch version >= '1.6.0', the type of
         # input data for deformconv might be torch.float or torch.half
         if (TORCH_VERSION != 'parrots'
                 and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
             with autocast(enabled=True):
-                self._test_amp_deformconv(torch.float, 1e-1)
-                self._test_amp_deformconv(torch.half, 1e-1)
+                self._test_amp_deformconv(torch.float, 1e-1, device)
+                self._test_amp_deformconv(torch.half, 1e-1, device)
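The divisibility check exercised by the test above can also be reproduced standalone. A small sketch (assumes a build where DCN's CPU path is available, per the op table; the assertion fires before any device kernel runs):

```python
# batch_size % min(im2col_step, batch_size) must be 0, otherwise the
# forward pass raises an AssertionError.
import pytest
import torch
from mmcv.ops import DeformConv2dPack

conv = DeformConv2dPack(1, 1, kernel_size=2, im2col_step=3)
x = torch.randn(10, 1, 4, 4)  # 10 % 3 != 0 -> AssertionError
with pytest.raises(AssertionError, match='batch size must be divisible'):
    conv(x)
```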