Unverified Commit 952f4806 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Refactor swin transfomer so later we can reuse component for 3d version (#6088)

* Use List[int] instead of int for window_size and shift_size

* Make PatchMerging and SwinTransformerBlock able to handle 2d and 3d cases

* Separate patch embedding from SwinTransformer and enable to get model without head by specifying num_heads=None

* Dont use if before padding so it is fx friendly

* Put the handling on window_size edge cases on separate function and wrap with torch.fx.wrap so it is excluded from tracing

* Update the weight url to the converted weight with new structure

* Update the accuracy of swin_transformer

* Change assert to Exception and nit

* Make num_classes optional

* Add typing output for _fix_window_and_shift_size function

* init head to None to make it jit scriptable

* Revert the change to make num_classes optional

* Revert unneccesarry changes that might be risky

* Remove self.head declaration
parent 1d50dfa0
...@@ -39,18 +39,23 @@ class PatchMerging(nn.Module): ...@@ -39,18 +39,23 @@ class PatchMerging(nn.Module):
self.norm = norm_layer(4 * dim) self.norm = norm_layer(4 * dim)
def forward(self, x: Tensor): def forward(self, x: Tensor):
B, H, W, C = x.shape """
Args:
x (Tensor): input tensor with expected layout of [..., H, W, C]
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x) x = self.norm(x)
x = self.reduction(x) x = self.reduction(x) # ... H/2 W/2 2*C
x = x.view(B, H // 2, W // 2, 2 * C)
return x return x
...@@ -59,9 +64,9 @@ def shifted_window_attention( ...@@ -59,9 +64,9 @@ def shifted_window_attention(
qkv_weight: Tensor, qkv_weight: Tensor,
proj_weight: Tensor, proj_weight: Tensor,
relative_position_bias: Tensor, relative_position_bias: Tensor,
window_size: int, window_size: List[int],
num_heads: int, num_heads: int,
shift_size: int = 0, shift_size: List[int],
attention_dropout: float = 0.0, attention_dropout: float = 0.0,
dropout: float = 0.0, dropout: float = 0.0,
qkv_bias: Optional[Tensor] = None, qkv_bias: Optional[Tensor] = None,
...@@ -75,9 +80,9 @@ def shifted_window_attention( ...@@ -75,9 +80,9 @@ def shifted_window_attention(
qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
relative_position_bias (Tensor): The learned relative position bias added to attention. relative_position_bias (Tensor): The learned relative position bias added to attention.
window_size (int): Window size. window_size (List[int]): Window size.
num_heads (int): Number of attention heads. num_heads (int): Number of attention heads.
shift_size (int): Shift size for shifted window attention. Default: 0. shift_size (List[int]): Shift size for shifted window attention.
attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
dropout (float): Dropout ratio of output. Default: 0.0. dropout (float): Dropout ratio of output. Default: 0.0.
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
...@@ -87,23 +92,25 @@ def shifted_window_attention( ...@@ -87,23 +92,25 @@ def shifted_window_attention(
""" """
B, H, W, C = input.shape B, H, W, C = input.shape
# pad feature maps to multiples of window size # pad feature maps to multiples of window size
pad_r = (window_size - W % window_size) % window_size pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
pad_b = (window_size - H % window_size) % window_size pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
_, pad_H, pad_W, _ = x.shape _, pad_H, pad_W, _ = x.shape
# If window size is larger than feature size, there is no need to shift window. # If window size is larger than feature size, there is no need to shift window
if window_size == min(pad_H, pad_W): if window_size[0] >= pad_H:
shift_size = 0 shift_size[0] = 0
if window_size[1] >= pad_W:
shift_size[1] = 0
# cyclic shift # cyclic shift
if shift_size > 0: if sum(shift_size) > 0:
x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
# partition windows # partition windows
num_windows = (pad_H // window_size) * (pad_W // window_size) num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C) x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size * window_size, C) # B*nW, Ws*Ws, C x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C
# multi-head attention # multi-head attention
qkv = F.linear(x, qkv_weight, qkv_bias) qkv = F.linear(x, qkv_weight, qkv_bias)
...@@ -114,17 +121,18 @@ def shifted_window_attention( ...@@ -114,17 +121,18 @@ def shifted_window_attention(
# add relative position bias # add relative position bias
attn = attn + relative_position_bias attn = attn + relative_position_bias
if shift_size > 0: if sum(shift_size) > 0:
# generate attention mask # generate attention mask
attn_mask = x.new_zeros((pad_H, pad_W)) attn_mask = x.new_zeros((pad_H, pad_W))
slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None)) h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
count = 0 count = 0
for h in slices: for h in h_slices:
for w in slices: for w in w_slices:
attn_mask[h[0] : h[1], w[0] : w[1]] = count attn_mask[h[0] : h[1], w[0] : w[1]] = count
count += 1 count += 1
attn_mask = attn_mask.view(pad_H // window_size, window_size, pad_W // window_size, window_size) attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size * window_size) attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
...@@ -139,12 +147,12 @@ def shifted_window_attention( ...@@ -139,12 +147,12 @@ def shifted_window_attention(
x = F.dropout(x, p=dropout) x = F.dropout(x, p=dropout)
# reverse windows # reverse windows
x = x.view(B, pad_H // window_size, pad_W // window_size, window_size, window_size, C) x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)
# reverse cyclic shift # reverse cyclic shift
if shift_size > 0: if sum(shift_size) > 0:
x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2)) x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
# unpad features # unpad features
x = x[:, :H, :W, :].contiguous() x = x[:, :H, :W, :].contiguous()
...@@ -162,8 +170,8 @@ class ShiftedWindowAttention(nn.Module): ...@@ -162,8 +170,8 @@ class ShiftedWindowAttention(nn.Module):
def __init__( def __init__(
self, self,
dim: int, dim: int,
window_size: int, window_size: List[int],
shift_size: int, shift_size: List[int],
num_heads: int, num_heads: int,
qkv_bias: bool = True, qkv_bias: bool = True,
proj_bias: bool = True, proj_bias: bool = True,
...@@ -171,6 +179,8 @@ class ShiftedWindowAttention(nn.Module): ...@@ -171,6 +179,8 @@ class ShiftedWindowAttention(nn.Module):
dropout: float = 0.0, dropout: float = 0.0,
): ):
super().__init__() super().__init__()
if len(window_size) != 2 or len(shift_size) != 2:
raise ValueError("window_size and shift_size must be of length 2")
self.window_size = window_size self.window_size = window_size
self.shift_size = shift_size self.shift_size = shift_size
self.num_heads = num_heads self.num_heads = num_heads
...@@ -182,29 +192,35 @@ class ShiftedWindowAttention(nn.Module): ...@@ -182,29 +192,35 @@ class ShiftedWindowAttention(nn.Module):
# define a parameter table of relative position bias # define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter( self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH ) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window # get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size) coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size) coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index) self.register_buffer("relative_position_index", relative_position_index)
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
def forward(self, x: Tensor): def forward(self, x: Tensor):
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""
N = self.window_size[0] * self.window_size[1]
relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view( relative_position_bias = relative_position_bias.view(N, N, -1)
self.window_size * self.window_size, self.window_size * self.window_size, -1
)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
return shifted_window_attention( return shifted_window_attention(
...@@ -228,31 +244,33 @@ class SwinTransformerBlock(nn.Module): ...@@ -228,31 +244,33 @@ class SwinTransformerBlock(nn.Module):
Args: Args:
dim (int): Number of input channels. dim (int): Number of input channels.
num_heads (int): Number of attention heads. num_heads (int): Number of attention heads.
window_size (int): Window size. Default: 7. window_size (List[int]): Window size.
shift_size (int): Shift size for shifted window attention. Default: 0. shift_size (List[int]): Shift size for shifted window attention.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0. dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
""" """
def __init__( def __init__(
self, self,
dim: int, dim: int,
num_heads: int, num_heads: int,
window_size: int = 7, window_size: List[int],
shift_size: int = 0, shift_size: List[int],
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
dropout: float = 0.0, dropout: float = 0.0,
attention_dropout: float = 0.0, attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0, stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm, norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
): ):
super().__init__() super().__init__()
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.attn = ShiftedWindowAttention( self.attn = attn_layer(
dim, dim,
window_size, window_size,
shift_size, shift_size,
...@@ -281,11 +299,11 @@ class SwinTransformer(nn.Module): ...@@ -281,11 +299,11 @@ class SwinTransformer(nn.Module):
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
Shifted Windows" <https://arxiv.org/pdf/2103.14030>`_ paper. Shifted Windows" <https://arxiv.org/pdf/2103.14030>`_ paper.
Args: Args:
patch_size (int): Patch size. patch_size (List[int]): Patch size.
embed_dim (int): Patch embedding dimension. embed_dim (int): Patch embedding dimension.
depths (List(int)): Depth of each Swin Transformer layer. depths (List(int)): Depth of each Swin Transformer layer.
num_heads (List(int)): Number of attention heads in different layers. num_heads (List(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7. window_size (List[int]): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0. dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0.
...@@ -297,11 +315,11 @@ class SwinTransformer(nn.Module): ...@@ -297,11 +315,11 @@ class SwinTransformer(nn.Module):
def __init__( def __init__(
self, self,
patch_size: int, patch_size: List[int],
embed_dim: int, embed_dim: int,
depths: List[int], depths: List[int],
num_heads: List[int], num_heads: List[int],
window_size: int = 7, window_size: List[int],
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
dropout: float = 0.0, dropout: float = 0.0,
attention_dropout: float = 0.0, attention_dropout: float = 0.0,
...@@ -324,7 +342,9 @@ class SwinTransformer(nn.Module): ...@@ -324,7 +342,9 @@ class SwinTransformer(nn.Module):
# split image into non-overlapping patches # split image into non-overlapping patches
layers.append( layers.append(
nn.Sequential( nn.Sequential(
nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size), nn.Conv2d(
3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
),
Permute([0, 2, 3, 1]), Permute([0, 2, 3, 1]),
norm_layer(embed_dim), norm_layer(embed_dim),
) )
...@@ -344,7 +364,7 @@ class SwinTransformer(nn.Module): ...@@ -344,7 +364,7 @@ class SwinTransformer(nn.Module):
dim, dim,
num_heads[i_stage], num_heads[i_stage],
window_size=window_size, window_size=window_size,
shift_size=0 if i_layer % 2 == 0 else window_size // 2, shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
dropout=dropout, dropout=dropout,
attention_dropout=attention_dropout, attention_dropout=attention_dropout,
...@@ -381,11 +401,11 @@ class SwinTransformer(nn.Module): ...@@ -381,11 +401,11 @@ class SwinTransformer(nn.Module):
def _swin_transformer( def _swin_transformer(
patch_size: int, patch_size: List[int],
embed_dim: int, embed_dim: int,
depths: List[int], depths: List[int],
num_heads: List[int], num_heads: List[int],
window_size: int, window_size: List[int],
stochastic_depth_prob: float, stochastic_depth_prob: float,
weights: Optional[WeightsEnum], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
...@@ -508,11 +528,11 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * ...@@ -508,11 +528,11 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
weights = Swin_T_Weights.verify(weights) weights = Swin_T_Weights.verify(weights)
return _swin_transformer( return _swin_transformer(
patch_size=4, patch_size=[4, 4],
embed_dim=96, embed_dim=96,
depths=[2, 2, 6, 2], depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24], num_heads=[3, 6, 12, 24],
window_size=7, window_size=[7, 7],
stochastic_depth_prob=0.2, stochastic_depth_prob=0.2,
weights=weights, weights=weights,
progress=progress, progress=progress,
...@@ -544,11 +564,11 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * ...@@ -544,11 +564,11 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, *
weights = Swin_S_Weights.verify(weights) weights = Swin_S_Weights.verify(weights)
return _swin_transformer( return _swin_transformer(
patch_size=4, patch_size=[4, 4],
embed_dim=96, embed_dim=96,
depths=[2, 2, 18, 2], depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24], num_heads=[3, 6, 12, 24],
window_size=7, window_size=[7, 7],
stochastic_depth_prob=0.3, stochastic_depth_prob=0.3,
weights=weights, weights=weights,
progress=progress, progress=progress,
...@@ -580,11 +600,11 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * ...@@ -580,11 +600,11 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
weights = Swin_B_Weights.verify(weights) weights = Swin_B_Weights.verify(weights)
return _swin_transformer( return _swin_transformer(
patch_size=4, patch_size=[4, 4],
embed_dim=128, embed_dim=128,
depths=[2, 2, 18, 2], depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32], num_heads=[4, 8, 16, 32],
window_size=7, window_size=[7, 7],
stochastic_depth_prob=0.5, stochastic_depth_prob=0.5,
weights=weights, weights=weights,
progress=progress, progress=progress,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment