Unverified Commit 09f4b813 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Fix typing issue to make DeformConv2d scriptable (#4079)

parent 562b8463
...@@ -775,6 +775,10 @@ class TestDeformConv: ...@@ -775,6 +775,10 @@ class TestDeformConv:
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype) self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
def test_forward_scriptability(self):
# Non-regression test for https://github.com/pytorch/vision/issues/4078
torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
class TestFrozenBNT: class TestFrozenBNT:
def test_frozenbatchnorm2d_repr(self): def test_frozenbatchnorm2d_repr(self):
......
...@@ -149,7 +149,7 @@ class DeformConv2d(nn.Module): ...@@ -149,7 +149,7 @@ class DeformConv2d(nn.Module):
bound = 1 / math.sqrt(fan_in) bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound) init.uniform_(self.bias, -bound, bound)
def forward(self, input: Tensor, offset: Tensor, mask: Tensor = None) -> Tensor: def forward(self, input: Tensor, offset: Tensor, mask: Optional[Tensor] = None) -> Tensor:
""" """
Args: Args:
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment