import logging import math from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn import numpy as np from itertools import repeat import collections.abc import torch from torch import nn from torch.nn import functional as F from torch.nn.init import trunc_normal_ from torchvision import transforms from torchvision.transforms import InterpolationMode from functools import partial from itertools import repeat import collections.abc # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) return parse to_1tuple = _ntuple(1) to_2tuple = _ntuple(2) to_3tuple = _ntuple(3) to_4tuple = _ntuple(4) to_ntuple = _ntuple def make_divisible(v, divisor=8, min_value=None, round_limit=.9): min_value = min_value or divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than 10%. if new_v < round_limit * v: new_v += divisor return new_v def extend_tuple(x, n): # pads a tuple to specified n by padding with last value if not isinstance(x, (tuple, list)): x = (x,) else: x = tuple(x) pad_n = n - len(x) if pad_n <= 0: return x[:n] return x + (x[-1],) * pad_n class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): super(DropPath, self).__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep def forward(self, x): return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) def extra_repr(self): return f'drop_prob={round(self.drop_prob,3):0.3f}' class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, norm_layer=None, bias=True, drop=0., use_conv=False, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear self.fc1 = linear_layer(in_features, hidden_features, bias=False) self.act = act_layer() self.drop1 = nn.Dropout(0.05) self.fc2 = linear_layer(hidden_features, out_features, bias=False) self.scale = nn.Parameter(torch.ones(1)) with torch.no_grad(): nn.init.kaiming_uniform_(self.fc1.weight, a=math.sqrt(5)) nn.init.zeros_(self.fc2.weight) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.scale*x return x # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token: pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) omega /= embed_dim / 2. omega = 1. / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) return parse to_1tuple = _ntuple(1) to_2tuple = _ntuple(2) to_3tuple = _ntuple(3) to_4tuple = _ntuple(4) _int_or_tuple_2_t = Union[int, Tuple[int, int]] def window_partition( x: torch.Tensor, window_size: Tuple[int, int], ) -> torch.Tensor: """ Partition into non-overlapping windows with padding if needed. Args: x (tensor): input tokens with [B, H, W, C]. window_size (int): window size. Returns: windows: windows after partition with [B * num_windows, window_size, window_size, C]. (Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) return windows def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ C = windows.shape[-1] x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) return x def get_relative_position_index(win_h: int, win_w: int): # get pair-wise relative position index for each token inside the window coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += win_h - 1 # shift to start from 0 relative_coords[:, :, 1] += win_w - 1 relative_coords[:, :, 0] *= 2 * win_w - 1 return relative_coords.sum(-1) # Wh*Ww, Wh*Ww class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.reduction_2 = nn.Linear(2 * dim, dim, bias=False) self.norm = norm_layer(4 * dim) self.norm_2 = norm_layer(2 * dim) def forward(self, x): """ X bxcxgxg x: B, H*W, C """ size= self.input_resolution B, C, G,_ = x.shape assert G*G == size * size, "input feature has wrong size" x = x.reshape(x.shape[0],x.shape[1],-1) #bxcxl x = x.permute(0,2,1) #bxlxc B, _, C = x.shape x = x.view(B, size, size, C) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) x = self.norm_2(x) x = self.reduction_2(x) x = x.view(B,-1,C) x = x.permute(0,2,1) #bxcxl x = x.reshape(x.shape[0],x.shape[1],G//2,G//2) #bxcxl return x class WindowAttention(nn.Module): """ Window based multi-head self attention (W-MSA) module with relative position bias. It supports shifted and non-shifted windows. """ fused_attn: torch.jit.Final[bool] def __init__( self, dim: int, num_heads: int, head_dim: Optional[int] = None, window_size: _int_or_tuple_2_t = 7, qkv_bias: bool = True, attn_drop: float = 0., proj_drop: float = 0., ): """ Args: dim: Number of input channels. num_heads: Number of attention heads. head_dim: Number of channels per head (dim // num_heads if not set) window_size: The height and width of the window. qkv_bias: If True, add a learnable bias to query, key, value. attn_drop: Dropout ratio of attention weight. proj_drop: Dropout ratio of output. """ super().__init__() self.window_size = to_2tuple(window_size) # Wh, Ww win_h, win_w = self.window_size self.window_area = win_h * win_w self.num_heads = num_heads head_dim = head_dim or dim // num_heads attn_dim = head_dim * num_heads self.scale = head_dim ** -0.5 # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)) # get pair-wise relative position index for each token inside the window self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False) self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(attn_dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) def _get_rel_pos_bias(self) -> torch.Tensor: relative_position_bias = self.relative_position_bias_table[ self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww return relative_position_bias.unsqueeze(0) def forward(self, x, mask: Optional[torch.Tensor] = None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) q = q * self.scale attn = q @ k.transpose(-2, -1) attn = attn + self._get_rel_pos_bias() if mask is not None: num_win = mask.shape[0] attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) attn = self.attn_drop(attn) x = attn @ v x = x.transpose(1, 2).reshape(B_, N, -1) x = self.proj(x) x = self.proj_drop(x) return x class SwinTransformerBlock(nn.Module): """ Swin Transformer Block. """ def __init__( self, dim: int, hidden_dim:int, input_resolution: _int_or_tuple_2_t, num_heads: int = 4, head_dim: Optional[int] = None, window_size: _int_or_tuple_2_t = 7, shift_size: int = 0, mlp_ratio: float = 1., qkv_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., drop_path: float = 0., act_layer: Callable = nn.GELU, norm_layer: Callable = nn.LayerNorm ): """ Args: dim: Number of input channels. input_resolution: Input resolution. window_size: Window size. num_heads: Number of attention heads. head_dim: Enforce the number of channels per head shift_size: Shift size for SW-MSA. mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: If True, add a learnable bias to query, key, value. proj_drop: Dropout rate. attn_drop: Attention dropout rate. drop_path: Stochastic depth rate. act_layer: Activation layer. norm_layer: Normalization layer. """ super().__init__() self.dim = dim self.hidden_dim = hidden_dim self.input_resolution = input_resolution ws, ss = self._calc_window_shift(window_size, shift_size) self.window_size: Tuple[int, int] = ws self.shift_size: Tuple[int, int] = ss self.window_area = self.window_size[0] * self.window_size[1] self.mlp_ratio = mlp_ratio self.downsample = nn.Linear(dim,hidden_dim,bias=False) self.norm1 = norm_layer(hidden_dim) self.attn = WindowAttention( hidden_dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size), qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, ) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(hidden_dim) self.mlp = Mlp( in_features=hidden_dim, hidden_features=int(hidden_dim * 2), out_features = 1664, act_layer=act_layer, drop=proj_drop, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() attn_mask = self.calc_attn(self.input_resolution) self.register_buffer("attn_mask", attn_mask, persistent=False) def calc_attn(self,input_resolution): H, W = input_resolution H = math.ceil(H / self.window_size[0]) * self.window_size[0] W = math.ceil(W / self.window_size[1]) * self.window_size[1] img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 cnt = 0 for h in ( slice(0, -self.window_size[0]), slice(-self.window_size[0], -self.shift_size[0]), slice(-self.shift_size[0], None)): for w in ( slice(0, -self.window_size[1]), slice(-self.window_size[1], -self.shift_size[1]), slice(-self.shift_size[1], None)): img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_area) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) return attn_mask def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: target_window_size = to_2tuple(target_window_size) target_shift_size = to_2tuple(target_shift_size) window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] return tuple(window_size), tuple(shift_size) def _attn(self, x): B, H, W, C = x.shape # cyclic shift has_shift = any(self.shift_size) if has_shift: shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) else: shifted_x = x # pad for resolution not divisible by window size pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C attn_mask = self.attn_mask # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C shifted_x = shifted_x[:, :H, :W, :].contiguous() # reverse cyclic shift if has_shift: x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) else: x = shifted_x return x def forward(self, x): x = self.downsample(x) B, H, W, C = x.shape # C = hidden_dim x = x + self.drop_path1(self._attn(self.norm1(x))) x = x.reshape(B, -1, C) x = self.drop_path2(self.mlp(self.norm2(x))) x = x.reshape(B, H, W, self.dim) return x def get_abs_pos(abs_pos, tgt_size): # abs_pos: L, C # tgt_size: M # return: M, C src_size = int(math.sqrt(abs_pos.size(1))) tgt_size = int(math.sqrt(tgt_size)) dtype = abs_pos.dtype if src_size != tgt_size: return F.interpolate( abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), size=(tgt_size, tgt_size), mode="bicubic", align_corners=False, ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) else: return abs_pos class CrossWindowAttention(nn.Module): """ Patch Merging Layer. """ def __init__( self, image_size, dim, hidden_dim, head, window_size = 12 ): """ Args: dim: Number of input channels. out_dim: Number of output channels (or 2 * dim if None) norm_layer: Normalization layer. """ super().__init__() if isinstance(image_size, tuple) or isinstance(image_size, list): self.image_size = image_size else: self.image_size = (image_size,image_size) self.dim = dim self.window_size = window_size self.shift_size = window_size // 2 self.position_embedding = nn.Parameter(torch.zeros(1, self.image_size[0]*self.image_size[1], 1664)) trunc_normal_(self.position_embedding, std=.02) self.shift_attn = SwinTransformerBlock(dim=dim,hidden_dim=hidden_dim,input_resolution = self.image_size,num_heads=head,window_size =self.window_size,shift_size=self.shift_size) def forward(self, x,image_size): # X bxcxgxg B,C,G,_=x.shape x = x.reshape(x.shape[0],x.shape[1],-1) #bxcxl x = x.permute(0,2,1) #bxlxc B, L, C = x.shape residual = x H,W = image_size pos_embed = get_abs_pos(self.position_embedding,x.size(1)) x = x + pos_embed x = x.view(B,H,W,C) x = self.shift_attn(x) x = x.view(B,-1,C) x = x + residual x = x.permute(0,2,1) #bxcxl x = x.reshape(x.shape[0],x.shape[1],G,G) #bxcxl return x