Unverified Commit e75a3337 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Move out the pad operation from PatchMerging in swin transformer to make it fx compatible (#6252)

parent f14682a8
......@@ -25,6 +25,15 @@ __all__ = [
]
def _patch_merging_pad(x):
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
return x
torch.fx.wrap("_patch_merging_pad")
class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
......@@ -46,8 +55,7 @@ class PatchMerging(nn.Module):
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x = _patch_merging_pad(x)
x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment