Unverified Commit 0a375614 authored by ddonatien's avatar ddonatien Committed by GitHub
Browse files

Prevent kernal_normalizer to change mask dtype (#1210)

parent 17fa6670
...@@ -268,7 +268,7 @@ class CARAFEPack(nn.Module): ...@@ -268,7 +268,7 @@ class CARAFEPack(nn.Module):
mask_channel = int(mask_c / float(self.up_kernel**2)) mask_channel = int(mask_c / float(self.up_kernel**2))
mask = mask.view(n, mask_channel, -1, h, w) mask = mask.view(n, mask_channel, -1, h, w)
mask = F.softmax(mask, dim=2) mask = F.softmax(mask, dim=2, dtype=mask.dtype)
mask = mask.view(n, mask_c, h, w).contiguous() mask = mask.view(n, mask_c, h, w).contiguous()
return mask return mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment