Unverified commit 5521e9d0 authored by Local State, committed by GitHub

Add SwinV2 (#6246)



* init submit

* fix typo

* support ufmt and mypy

* fix 2 unittest errors

* fix ufmt issue

* Apply suggestions from code review
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* unify codes

* fix meshgrid indexing

* fix a bug

* fix type check

* add type_annotation

* add slow model

* fix device issue

* fix ufmt issue

* add expect pickle file

* fix jit script issue

* fix type check

* keep consistent argument order

* add support for pretrained_window_size

* avoid code duplication

* a better code reuse

* update window_size argument

* make permute and flatten operations modular

* add PatchMergingV2

* modify expect.pkl

* use None as default argument value

* fix type check

* fix indent

* fix window_size (temporarily)

* remove "v2_" related prefix and add v2 builder

* remove v2 builder

* keep default value consistent with official repo

* deprecate dropout

* deprecate pretrained_window_size

* fix dynamic padding edge case

* remove unused imports

* remove doc modification

* Revert "deprecate dropout"

This reverts commit 8a13f932815ae25655c07430d52929f86b1ca479.

* Revert "fix dynamic padding edge case"

This reverts commit 1c7579cb1bd7bf2f0f94907f39bee6ed707a97a8.

* remove unused kwargs

* add downsample docs

* revert block default value

* revert argument order change

* explicitly specify start_dim

* add small and base variants

* add expect files and slow_models

* Add model weights and documentation for swin v2

* fix lint

* fix end of files line
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Joao Gomes <jdsgomes@fb.com>
parent 7e8186e0
@@ -3,16 +3,18 @@ SwinTransformer
.. currentmodule:: torchvision.models

The SwinTransformer models are based on the `Swin Transformer: Hierarchical Vision
Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`__
paper.

SwinTransformer V2 models are based on the `Swin Transformer V2: Scaling Up Capacity
and Resolution <https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_Swin_Transformer_V2_Scaling_Up_Capacity_and_Resolution_CVPR_2022_paper.pdf>`__
paper.
Model builders
--------------

The following model builders can be used to instantiate a SwinTransformer model (original and V2) with and without pre-trained weights.
All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_ for
@@ -25,3 +27,6 @@ more details about this class.
swin_t
swin_s
swin_b
swin_v2_t
swin_v2_s
swin_v2_b
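For illustration, a minimal usage sketch of the new V2 builders and weights documented above (it assumes the standard torchvision weights API, where each entry exposes its inference preset via ``transforms()``):

```python
import torch
from torchvision.models import swin_v2_t, Swin_V2_T_Weights

# Build the tiny V2 variant with the ImageNet-1K weights added in this PR.
weights = Swin_V2_T_Weights.IMAGENET1K_V1
model = swin_v2_t(weights=weights).eval()

# The weights entry carries its own preprocessing (bicubic resize to 260, 256x256 center crop).
preprocess = weights.transforms()
batch = preprocess(torch.rand(3, 300, 300)).unsqueeze(0)

with torch.no_grad():
    logits = model(batch)
print(logits.shape)  # torch.Size([1, 1000])
```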
@@ -236,6 +236,17 @@ Note that `--val-resize-size` was optimized in a post-training step, see their `
### SwinTransformer V2
```
torchrun --nproc_per_node=8 train.py\
--model $MODEL --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 256 --val-crop-size 256 --train-crop-size 256
```
Here `$MODEL` is one of `swin_v2_t`, `swin_v2_s` or `swin_v2_b`.
Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
### ShuffleNet V2
```
torchrun --nproc_per_node=8 train.py \
...
@@ -332,6 +332,9 @@ slow_models = [
"swin_t",
"swin_s",
"swin_b",
"swin_v2_t",
"swin_v2_s",
"swin_v2_b",
]
for m in slow_models:
_model_params[m] = {"input_shape": (1, 3, 64, 64)}
...
import math
from functools import partial
from typing import Any, Callable, List, Optional
@@ -19,21 +20,45 @@ __all__ = [
"Swin_T_Weights",
"Swin_S_Weights",
"Swin_B_Weights",
"Swin_V2_T_Weights",
"Swin_V2_S_Weights",
"Swin_V2_B_Weights",
"swin_t", "swin_t",
"swin_s", "swin_s",
"swin_b", "swin_b",
"swin_v2_t",
"swin_v2_s",
"swin_v2_b",
]
def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor:
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
return x
torch.fx.wrap("_patch_merging_pad")
def _get_relative_position_bias(
relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
) -> torch.Tensor:
N = window_size[0] * window_size[1]
relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view(N, N, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
return relative_position_bias
torch.fx.wrap("_get_relative_position_bias")
class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
@@ -56,15 +81,35 @@ class PatchMerging(nn.Module):
Tensor with layout of [..., H/2, W/2, 2*C]
"""
x = _patch_merging_pad(x)
x = self.norm(x)
x = self.reduction(x) # ... H/2 W/2 2*C
return x
class PatchMergingV2(nn.Module):
"""Patch Merging Layer for Swin Transformer V2.
Args:
dim (int): Number of input channels.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
"""
def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
super().__init__()
_log_api_usage_once(self)
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(2 * dim) # difference
def forward(self, x: Tensor):
"""
Args:
x (Tensor): input tensor with expected layout of [..., H, W, C]
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
x = _patch_merging_pad(x)
x = self.reduction(x) # ... H/2 W/2 2*C
x = self.norm(x)
return x
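For a quick sanity check of the difference (V1 normalizes the 4*C concatenation before the linear reduction, V2 reduces to 2*C first and normalizes the result), both layers keep the same output layout; a small sketch, assuming both classes are importable from ``torchvision.models.swin_transformer``:

```python
import torch
from torchvision.models.swin_transformer import PatchMerging, PatchMergingV2

x = torch.rand(1, 14, 14, 96)  # [..., H, W, C]

# V1: LayerNorm over the 4*C concatenated neighbours, then Linear(4*C -> 2*C).
print(PatchMerging(96)(x).shape)    # torch.Size([1, 7, 7, 192])
# V2: Linear(4*C -> 2*C) first, then LayerNorm over the reduced 2*C features.
print(PatchMergingV2(96)(x).shape)  # torch.Size([1, 7, 7, 192])
```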
@@ -80,6 +125,7 @@ def shifted_window_attention(
dropout: float = 0.0,
qkv_bias: Optional[Tensor] = None,
proj_bias: Optional[Tensor] = None,
logit_scale: Optional[torch.Tensor] = None,
):
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
@@ -96,6 +142,7 @@
dropout (float): Dropout ratio of output. Default: 0.0.
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
Returns:
Tensor[N, H, W, C]: The output tensor after shifted window attention.
"""
@@ -123,11 +170,21 @@
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C
# multi-head attention
if logit_scale is not None and qkv_bias is not None:
qkv_bias = qkv_bias.clone()
length = qkv_bias.numel() // 3
qkv_bias[length : 2 * length].zero_()
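# in Swin V2 the key has no learnable bias: only the query and value thirds of the fused qkv bias are kept, so the middle (key) slice is zeroed here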
qkv = F.linear(x, qkv_weight, qkv_bias)
qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
if logit_scale is not None:
# cosine attention
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
attn = attn * logit_scale
else:
q = q * (C // num_heads) ** -0.5
attn = q.matmul(k.transpose(-2, -1))
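# the if-branch above is V2's cosine attention: q and k are L2-normalized so each logit is a cosine similarity, rescaled by the learnt per-head logit_scale (clamped to at most 100); the else-branch is V1's fixed 1/sqrt(head_dim) scaling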
# add relative position bias
attn = attn + relative_position_bias
@@ -200,11 +257,17 @@ class ShiftedWindowAttention(nn.Module):
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.define_relative_position_bias_table()
self.define_relative_position_index()
def define_relative_position_bias_table(self):
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
def define_relative_position_index(self):
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
@@ -215,10 +278,13 @@ class ShiftedWindowAttention(nn.Module):
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
def get_relative_position_bias(self) -> torch.Tensor:
return _get_relative_position_bias(
self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type]
)
def forward(self, x: Tensor):
"""
@@ -227,12 +293,91 @@
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""
relative_position_bias = self.get_relative_position_bias()
return shifted_window_attention(
x,
self.qkv.weight,
self.proj.weight,
relative_position_bias,
self.window_size,
self.num_heads,
shift_size=self.shift_size,
attention_dropout=self.attention_dropout,
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
)
class ShiftedWindowAttentionV2(ShiftedWindowAttention):
"""
See :func:`shifted_window_attention_v2`.
"""
def __init__(
self,
dim: int,
window_size: List[int],
shift_size: List[int],
num_heads: int,
qkv_bias: bool = True,
proj_bias: bool = True,
attention_dropout: float = 0.0,
dropout: float = 0.0,
):
super().__init__(
dim,
window_size,
shift_size,
num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attention_dropout=attention_dropout,
dropout=dropout,
)
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
# mlp to generate continuous relative position bias
self.cpb_mlp = nn.Sequential(
nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
)
if qkv_bias:
length = self.qkv.bias.numel() // 3
self.qkv.bias[length : 2 * length].data.zero_()
def define_relative_position_bias_table(self):
# get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = (
torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0
)
self.register_buffer("relative_coords_table", relative_coords_table)
def get_relative_position_bias(self) -> torch.Tensor:
relative_position_bias = _get_relative_position_bias(
self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads),
self.relative_position_index, # type: ignore[arg-type]
self.window_size,
)
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
return relative_position_bias
def forward(self, x: Tensor):
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""
relative_position_bias = self.get_relative_position_bias()
return shifted_window_attention(
x,
self.qkv.weight,
@@ -245,6 +390,7 @@ class ShiftedWindowAttention(nn.Module):
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
logit_scale=self.logit_scale,
)
@@ -305,6 +451,54 @@ class SwinTransformerBlock(nn.Module):
return x
class SwinTransformerBlockV2(SwinTransformerBlock):
"""
Swin Transformer V2 Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (List[int]): Window size.
shift_size (List[int]): Shift size for shifted window attention.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2.
"""
def __init__(
self,
dim: int,
num_heads: int,
window_size: List[int],
shift_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2,
):
super().__init__(
dim,
num_heads,
window_size,
shift_size,
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_dropout=attention_dropout,
stochastic_depth_prob=stochastic_depth_prob,
norm_layer=norm_layer,
attn_layer=attn_layer,
)
def forward(self, x: Tensor):
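# residual post-normalization: unlike the V1 block, which normalizes the inputs of attention/MLP (pre-norm), V2 normalizes their outputs before the residual addition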
x = x + self.stochastic_depth(self.norm1(self.attn(x)))
x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
return x
class SwinTransformer(nn.Module):
"""
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
@@ -318,10 +512,11 @@ class SwinTransformer(nn.Module):
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
num_classes (int): Number of classes for classification head. Default: 1000.
block (nn.Module, optional): SwinTransformer Block. Default: None.
norm_layer (nn.Module, optional): Normalization layer. Default: None.
downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging.
""" """
def __init__(
@@ -334,10 +529,11 @@
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.1,
num_classes: int = 1000,
norm_layer: Optional[Callable[..., nn.Module]] = None,
block: Optional[Callable[..., nn.Module]] = None,
downsample_layer: Callable[..., nn.Module] = PatchMerging,
):
super().__init__()
_log_api_usage_once(self)
@@ -345,7 +541,6 @@
if block is None:
block = SwinTransformerBlock
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-5)
@@ -387,12 +582,14 @@
layers.append(nn.Sequential(*stage))
# add patch merging layer
if i_stage < (len(depths) - 1):
layers.append(downsample_layer(dim, norm_layer))
self.features = nn.Sequential(*layers)
num_features = embed_dim * 2 ** (len(depths) - 1)
self.norm = norm_layer(num_features)
self.permute = Permute([0, 3, 1, 2])
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.flatten = nn.Flatten(1)
self.head = nn.Linear(num_features, num_classes)
for m in self.modules():
@@ -404,9 +601,9 @@
def forward(self, x):
x = self.features(x)
x = self.norm(x)
x = self.permute(x)
x = self.avgpool(x)
x = self.flatten(x)
x = self.head(x)
return x
@@ -515,6 +712,75 @@ class Swin_B_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
class Swin_V2_T_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth",
transforms=partial(
ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META,
"num_params": 28351570,
"min_size": (256, 256),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
"_metrics": {
"ImageNet-1K": {
"acc@1": 82.072,
"acc@5": 96.132,
}
},
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
},
)
DEFAULT = IMAGENET1K_V1
class Swin_V2_S_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth",
transforms=partial(
ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META,
"num_params": 49737442,
"min_size": (256, 256),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
"_metrics": {
"ImageNet-1K": {
"acc@1": 83.712,
"acc@5": 96.816,
}
},
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
},
)
DEFAULT = IMAGENET1K_V1
class Swin_V2_B_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_v2_b-781e5279.pth",
transforms=partial(
ImageClassification, crop_size=256, resize_size=272, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META,
"num_params": 87930848,
"min_size": (256, 256),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
"_metrics": {
"ImageNet-1K": {
"acc@1": 84.112,
"acc@5": 96.864,
}
},
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
},
)
DEFAULT = IMAGENET1K_V1
@register_model()
def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
@@ -624,3 +890,120 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
progress=progress,
**kwargs,
)
@register_model()
def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_v2_tiny architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_.
Args:
weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.Swin_V2_T_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.Swin_V2_T_Weights
:members:
"""
weights = Swin_V2_T_Weights.verify(weights)
return _swin_transformer(
patch_size=[4, 4],
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=[8, 8],
stochastic_depth_prob=0.2,
weights=weights,
progress=progress,
block=SwinTransformerBlockV2,
downsample_layer=PatchMergingV2,
**kwargs,
)
@register_model()
def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_v2_small architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_.
Args:
weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.Swin_V2_S_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.Swin_V2_S_Weights
:members:
"""
weights = Swin_V2_S_Weights.verify(weights)
return _swin_transformer(
patch_size=[4, 4],
embed_dim=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=[8, 8],
stochastic_depth_prob=0.3,
weights=weights,
progress=progress,
block=SwinTransformerBlockV2,
downsample_layer=PatchMergingV2,
**kwargs,
)
@register_model()
def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_v2_base architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_.
Args:
weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.Swin_V2_B_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.Swin_V2_B_Weights
:members:
"""
weights = Swin_V2_B_Weights.verify(weights)
return _swin_transformer(
patch_size=[4, 4],
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=[8, 8],
stochastic_depth_prob=0.5,
weights=weights,
progress=progress,
block=SwinTransformerBlockV2,
downsample_layer=PatchMergingV2,
**kwargs,
)
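To show how the new ``block`` and ``downsample_layer`` arguments compose, here is a sketch that wires the V2 components into the base class directly, mirroring what ``swin_v2_t`` does above (the two-stage depths/heads below are hypothetical, shortened only to keep the example small):

```python
import torch
from torchvision.models.swin_transformer import (
    PatchMergingV2,
    SwinTransformer,
    SwinTransformerBlockV2,
)

# A reduced V2-style model built straight from the base class.
model = SwinTransformer(
    patch_size=[4, 4],
    embed_dim=96,
    depths=[2, 2],      # hypothetical: the real swin_v2_t uses [2, 2, 6, 2]
    num_heads=[3, 6],   # hypothetical: the real swin_v2_t uses [3, 6, 12, 24]
    window_size=[8, 8],
    stochastic_depth_prob=0.2,
    block=SwinTransformerBlockV2,
    downsample_layer=PatchMergingV2,
)
print(model(torch.rand(1, 3, 256, 256)).shape)  # torch.Size([1, 1000])
```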