"docs/vscode:/vscode.git/clone" did not exist on "a2b11de487269f81c6cdbe17ac9fa4c8c585da1b"
Unverified Commit 71ee2a61 authored by mengpenghui's avatar mengpenghui Committed by GitHub
Browse files

[Enhance] Add AMP support for MLU_DCNv2 (#2548)

parent c310d28c
......@@ -406,10 +406,13 @@ if IS_MLU_AVAILABLE:
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
x = x.type_as(offset)
weight = self.weight.type_as(x)
mask = mask.type_as(x)
return tv_deform_conv2d(
x,
offset,
self.weight,
weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
......
......@@ -74,7 +74,7 @@ class TestMdconv:
assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
dcn_offset_b_grad, 1e-2)
def _test_amp_mdconv(self, input_dtype=torch.float):
def _test_amp_mdconv(self, input_dtype=torch.float, device='cuda'):
"""The function to test amp released on pytorch 1.6.0.
The type of input data might be torch.float or torch.half,
......@@ -84,10 +84,15 @@ class TestMdconv:
Args:
input_dtype: torch.float or torch.half.
"""
if not torch.cuda.is_available():
if not torch.cuda.is_available() and device == 'cuda':
return
from mmcv.ops import ModulatedDeformConv2dPack
input = torch.tensor(input_t).cuda().type(input_dtype)
if device == 'mlu':
from mmcv.ops import \
ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
else:
from mmcv.ops import ModulatedDeformConv2dPack
input = torch.tensor(input_t).to(device).type(input_dtype)
input.requires_grad = True
dcn = ModulatedDeformConv2dPack(
......@@ -97,7 +102,7 @@ class TestMdconv:
stride=1,
padding=1,
deform_groups=1,
bias=False).cuda()
bias=False).to(device)
dcn.weight.data.fill_(1.)
output = dcn(input)
output.sum().backward()
......@@ -126,5 +131,5 @@ class TestMdconv:
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True):
self._test_amp_mdconv(torch.float)
self._test_amp_mdconv(torch.half)
self._test_amp_mdconv(torch.float, device=device)
self._test_amp_mdconv(torch.half, device=device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment