Unverified Commit 71ee2a61 authored by mengpenghui's avatar mengpenghui Committed by GitHub
Browse files

[Enhance] Add AMP support for MLU_DCNv2 (#2548)

parent c310d28c
...@@ -406,10 +406,13 @@ if IS_MLU_AVAILABLE: ...@@ -406,10 +406,13 @@ if IS_MLU_AVAILABLE:
o1, o2, mask = torch.chunk(out, 3, dim=1) o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1) offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask) mask = torch.sigmoid(mask)
x = x.type_as(offset)
weight = self.weight.type_as(x)
mask = mask.type_as(x)
return tv_deform_conv2d( return tv_deform_conv2d(
x, x,
offset, offset,
self.weight, weight,
bias=self.bias, bias=self.bias,
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
......
...@@ -74,7 +74,7 @@ class TestMdconv: ...@@ -74,7 +74,7 @@ class TestMdconv:
assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(), assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
dcn_offset_b_grad, 1e-2) dcn_offset_b_grad, 1e-2)
def _test_amp_mdconv(self, input_dtype=torch.float): def _test_amp_mdconv(self, input_dtype=torch.float, device='cuda'):
"""The function to test amp released on pytorch 1.6.0. """The function to test amp released on pytorch 1.6.0.
The type of input data might be torch.float or torch.half, The type of input data might be torch.float or torch.half,
...@@ -84,10 +84,15 @@ class TestMdconv: ...@@ -84,10 +84,15 @@ class TestMdconv:
Args: Args:
input_dtype: torch.float or torch.half. input_dtype: torch.float or torch.half.
""" """
if not torch.cuda.is_available(): if not torch.cuda.is_available() and device == 'cuda':
return return
if device == 'mlu':
from mmcv.ops import \
ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
else:
from mmcv.ops import ModulatedDeformConv2dPack from mmcv.ops import ModulatedDeformConv2dPack
input = torch.tensor(input_t).cuda().type(input_dtype)
input = torch.tensor(input_t).to(device).type(input_dtype)
input.requires_grad = True input.requires_grad = True
dcn = ModulatedDeformConv2dPack( dcn = ModulatedDeformConv2dPack(
...@@ -97,7 +102,7 @@ class TestMdconv: ...@@ -97,7 +102,7 @@ class TestMdconv:
stride=1, stride=1,
padding=1, padding=1,
deform_groups=1, deform_groups=1,
bias=False).cuda() bias=False).to(device)
dcn.weight.data.fill_(1.) dcn.weight.data.fill_(1.)
output = dcn(input) output = dcn(input)
output.sum().backward() output.sum().backward()
...@@ -126,5 +131,5 @@ class TestMdconv: ...@@ -126,5 +131,5 @@ class TestMdconv:
if (TORCH_VERSION != 'parrots' if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True): with autocast(enabled=True):
self._test_amp_mdconv(torch.float) self._test_amp_mdconv(torch.float, device=device)
self._test_amp_mdconv(torch.half) self._test_amp_mdconv(torch.half, device=device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment