Unverified Commit 5521e9d0 authored by Local State's avatar Local State Committed by GitHub
Browse files

Add SwinV2 (#6246)



* init submit

* fix typo

* support ufmt and mypy

* fix 2 unittest errors

* fix ufmt issue

* Apply suggestions from code review
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>

* unify codes

* fix meshgrid indexing

* fix a bug

* fix type check

* add type_annotation

* add slow model

* fix device issue

* fix ufmt issue

* add expect pickle file

* fix jit script issue

* fix type check

* keep consistent argument order

* add support for pretrained_window_size

* avoid code duplication

* a better code reuse

* update window_size argument

* make permute and flatten operations modular

* add PatchMergingV2

* modify expect.pkl

* use None as default argument value

* fix type check

* fix indent

* fix window_size (temporarily)

* remove "v2_" related prefix and add v2 builder

* remove v2 builder

* keep default value consistent with official repo

* deprecate dropout

* deprecate pretrained_window_size

* fix dynamic padding edge case

* remove unused imports

* remove doc modification

* Revert "deprecate dropout"

This reverts commit 8a13f932815ae25655c07430d52929f86b1ca479.

* Revert "fix dynamic padding edge case"

This reverts commit 1c7579cb1bd7bf2f0f94907f39bee6ed707a97a8.

* remove unused kwargs

* add downsample docs

* revert block default value

* revert argument order change

* explicitly specify start_dim

* add small and base variants

* add expect files and slow_models

* Add model weights and documentation for swin v2

* fix lint

* fix end of files line
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: default avatarJoao Gomes <jdsgomes@fb.com>
parent 7e8186e0
......@@ -3,16 +3,18 @@ SwinTransformer
.. currentmodule:: torchvision.models
The SwinTransformer model is based on the `Swin Transformer: Hierarchical Vision
The SwinTransformer models are based on the `Swin Transformer: Hierarchical Vision
Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`__
paper.
SwinTransformer V2 models are based on the `Swin Transformer V2: Scaling Up Capacity
and Resolution <https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_Swin_Transformer_V2_Scaling_Up_Capacity_and_Resolution_CVPR_2022_paper.pdf>`__
paper.
Model builders
--------------
The following model builders can be used to instantiate an SwinTransformer model.
`swin_t` can be instantiated with pre-trained weights and all others without.
The following model builders can be used to instantiate an SwinTransformer model (original and V2) with and without pre-trained weights.
All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_ for
......@@ -25,3 +27,6 @@ more details about this class.
swin_t
swin_s
swin_b
swin_v2_t
swin_v2_s
swin_v2_b
......@@ -236,6 +236,17 @@ Note that `--val-resize-size` was optimized in a post-training step, see their `
### SwinTransformer V2
```
torchrun --nproc_per_node=8 train.py\
--model $MODEL --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 256 --val-crop-size 256 --train-crop-size 256
```
Here `$MODEL` is one of `swin_v2_t`, `swin_v2_s` or `swin_v2_b`.
Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
### ShuffleNet V2
```
torchrun --nproc_per_node=8 train.py \
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
......@@ -332,6 +332,9 @@ slow_models = [
"swin_t",
"swin_s",
"swin_b",
"swin_v2_t",
"swin_v2_s",
"swin_v2_b",
]
for m in slow_models:
_model_params[m] = {"input_shape": (1, 3, 64, 64)}
......
import math
from functools import partial
from typing import Any, Callable, List, Optional
......@@ -19,21 +20,45 @@ __all__ = [
"Swin_T_Weights",
"Swin_S_Weights",
"Swin_B_Weights",
"Swin_V2_T_Weights",
"Swin_V2_S_Weights",
"Swin_V2_B_Weights",
"swin_t",
"swin_s",
"swin_b",
"swin_v2_t",
"swin_v2_s",
"swin_v2_b",
]
def _patch_merging_pad(x):
def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor:
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
return x
torch.fx.wrap("_patch_merging_pad")
def _get_relative_position_bias(
relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
) -> torch.Tensor:
N = window_size[0] * window_size[1]
relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view(N, N, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
return relative_position_bias
torch.fx.wrap("_get_relative_position_bias")
class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
......@@ -56,15 +81,35 @@ class PatchMerging(nn.Module):
Tensor with layout of [..., H/2, W/2, 2*C]
"""
x = _patch_merging_pad(x)
x = self.norm(x)
x = self.reduction(x) # ... H/2 W/2 2*C
return x
x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
x = self.norm(x)
class PatchMergingV2(nn.Module):
"""Patch Merging Layer for Swin Transformer V2.
Args:
dim (int): Number of input channels.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
"""
def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
super().__init__()
_log_api_usage_once(self)
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(2 * dim) # difference
def forward(self, x: Tensor):
"""
Args:
x (Tensor): input tensor with expected layout of [..., H, W, C]
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
x = _patch_merging_pad(x)
x = self.reduction(x) # ... H/2 W/2 2*C
x = self.norm(x)
return x
......@@ -80,6 +125,7 @@ def shifted_window_attention(
dropout: float = 0.0,
qkv_bias: Optional[Tensor] = None,
proj_bias: Optional[Tensor] = None,
logit_scale: Optional[torch.Tensor] = None,
):
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
......@@ -96,6 +142,7 @@ def shifted_window_attention(
dropout (float): Dropout ratio of output. Default: 0.0.
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
Returns:
Tensor[N, H, W, C]: The output tensor after shifted window attention.
"""
......@@ -123,11 +170,21 @@ def shifted_window_attention(
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C
# multi-head attention
if logit_scale is not None and qkv_bias is not None:
qkv_bias = qkv_bias.clone()
length = qkv_bias.numel() // 3
qkv_bias[length : 2 * length].zero_()
qkv = F.linear(x, qkv_weight, qkv_bias)
qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * (C // num_heads) ** -0.5
attn = q.matmul(k.transpose(-2, -1))
if logit_scale is not None:
# cosine attention
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
attn = attn * logit_scale
else:
q = q * (C // num_heads) ** -0.5
attn = q.matmul(k.transpose(-2, -1))
# add relative position bias
attn = attn + relative_position_bias
......@@ -200,11 +257,17 @@ class ShiftedWindowAttention(nn.Module):
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.define_relative_position_bias_table()
self.define_relative_position_index()
def define_relative_position_bias_table(self):
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
def define_relative_position_index(self):
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
......@@ -215,10 +278,13 @@ class ShiftedWindowAttention(nn.Module):
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww
relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
def get_relative_position_bias(self) -> torch.Tensor:
return _get_relative_position_bias(
self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type]
)
def forward(self, x: Tensor):
"""
......@@ -227,12 +293,91 @@ class ShiftedWindowAttention(nn.Module):
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""
relative_position_bias = self.get_relative_position_bias()
return shifted_window_attention(
x,
self.qkv.weight,
self.proj.weight,
relative_position_bias,
self.window_size,
self.num_heads,
shift_size=self.shift_size,
attention_dropout=self.attention_dropout,
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
)
class ShiftedWindowAttentionV2(ShiftedWindowAttention):
"""
See :func:`shifted_window_attention_v2`.
"""
def __init__(
self,
dim: int,
window_size: List[int],
shift_size: List[int],
num_heads: int,
qkv_bias: bool = True,
proj_bias: bool = True,
attention_dropout: float = 0.0,
dropout: float = 0.0,
):
super().__init__(
dim,
window_size,
shift_size,
num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attention_dropout=attention_dropout,
dropout=dropout,
)
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
# mlp to generate continuous relative position bias
self.cpb_mlp = nn.Sequential(
nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
)
if qkv_bias:
length = self.qkv.bias.numel() // 3
self.qkv.bias[length : 2 * length].data.zero_()
def define_relative_position_bias_table(self):
# get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = (
torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0
)
self.register_buffer("relative_coords_table", relative_coords_table)
N = self.window_size[0] * self.window_size[1]
relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view(N, N, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
def get_relative_position_bias(self) -> torch.Tensor:
relative_position_bias = _get_relative_position_bias(
self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads),
self.relative_position_index, # type: ignore[arg-type]
self.window_size,
)
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
return relative_position_bias
def forward(self, x: Tensor):
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""
relative_position_bias = self.get_relative_position_bias()
return shifted_window_attention(
x,
self.qkv.weight,
......@@ -245,6 +390,7 @@ class ShiftedWindowAttention(nn.Module):
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
logit_scale=self.logit_scale,
)
......@@ -305,6 +451,54 @@ class SwinTransformerBlock(nn.Module):
return x
class SwinTransformerBlockV2(SwinTransformerBlock):
"""
Swin Transformer V2 Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (List[int]): Window size.
shift_size (List[int]): Shift size for shifted window attention.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2.
"""
def __init__(
self,
dim: int,
num_heads: int,
window_size: List[int],
shift_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2,
):
super().__init__(
dim,
num_heads,
window_size,
shift_size,
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_dropout=attention_dropout,
stochastic_depth_prob=stochastic_depth_prob,
norm_layer=norm_layer,
attn_layer=attn_layer,
)
def forward(self, x: Tensor):
x = x + self.stochastic_depth(self.norm1(self.attn(x)))
x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
return x
class SwinTransformer(nn.Module):
"""
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
......@@ -318,10 +512,11 @@ class SwinTransformer(nn.Module):
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
num_classes (int): Number of classes for classification head. Default: 1000.
block (nn.Module, optional): SwinTransformer Block. Default: None.
norm_layer (nn.Module, optional): Normalization layer. Default: None.
downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging.
"""
def __init__(
......@@ -334,10 +529,11 @@ class SwinTransformer(nn.Module):
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
stochastic_depth_prob: float = 0.1,
num_classes: int = 1000,
norm_layer: Optional[Callable[..., nn.Module]] = None,
block: Optional[Callable[..., nn.Module]] = None,
downsample_layer: Callable[..., nn.Module] = PatchMerging,
):
super().__init__()
_log_api_usage_once(self)
......@@ -345,7 +541,6 @@ class SwinTransformer(nn.Module):
if block is None:
block = SwinTransformerBlock
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-5)
......@@ -387,12 +582,14 @@ class SwinTransformer(nn.Module):
layers.append(nn.Sequential(*stage))
# add patch merging layer
if i_stage < (len(depths) - 1):
layers.append(PatchMerging(dim, norm_layer))
layers.append(downsample_layer(dim, norm_layer))
self.features = nn.Sequential(*layers)
num_features = embed_dim * 2 ** (len(depths) - 1)
self.norm = norm_layer(num_features)
self.permute = Permute([0, 3, 1, 2])
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.flatten = nn.Flatten(1)
self.head = nn.Linear(num_features, num_classes)
for m in self.modules():
......@@ -404,9 +601,9 @@ class SwinTransformer(nn.Module):
def forward(self, x):
x = self.features(x)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
x = self.permute(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.flatten(x)
x = self.head(x)
return x
......@@ -515,6 +712,75 @@ class Swin_B_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
class Swin_V2_T_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth",
transforms=partial(
ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META,
"num_params": 28351570,
"min_size": (256, 256),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
"_metrics": {
"ImageNet-1K": {
"acc@1": 82.072,
"acc@5": 96.132,
}
},
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
},
)
DEFAULT = IMAGENET1K_V1
class Swin_V2_S_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth",
transforms=partial(
ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META,
"num_params": 49737442,
"min_size": (256, 256),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
"_metrics": {
"ImageNet-1K": {
"acc@1": 83.712,
"acc@5": 96.816,
}
},
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
},
)
DEFAULT = IMAGENET1K_V1
class Swin_V2_B_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_v2_b-781e5279.pth",
transforms=partial(
ImageClassification, crop_size=256, resize_size=272, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META,
"num_params": 87930848,
"min_size": (256, 256),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
"_metrics": {
"ImageNet-1K": {
"acc@1": 84.112,
"acc@5": 96.864,
}
},
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
},
)
DEFAULT = IMAGENET1K_V1
@register_model()
def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
......@@ -624,3 +890,120 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
progress=progress,
**kwargs,
)
@register_model()
def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_v2_tiny architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_.
Args:
weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.Swin_V2_T_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.Swin_V2_T_Weights
:members:
"""
weights = Swin_V2_T_Weights.verify(weights)
return _swin_transformer(
patch_size=[4, 4],
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=[8, 8],
stochastic_depth_prob=0.2,
weights=weights,
progress=progress,
block=SwinTransformerBlockV2,
downsample_layer=PatchMergingV2,
**kwargs,
)
@register_model()
def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_v2_small architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_.
Args:
weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.Swin_V2_S_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.Swin_V2_S_Weights
:members:
"""
weights = Swin_V2_S_Weights.verify(weights)
return _swin_transformer(
patch_size=[4, 4],
embed_dim=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=[8, 8],
stochastic_depth_prob=0.3,
weights=weights,
progress=progress,
block=SwinTransformerBlockV2,
downsample_layer=PatchMergingV2,
**kwargs,
)
@register_model()
def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_v2_base architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_.
Args:
weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.Swin_V2_B_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.Swin_V2_B_Weights
:members:
"""
weights = Swin_V2_B_Weights.verify(weights)
return _swin_transformer(
patch_size=[4, 4],
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=[8, 8],
stochastic_depth_prob=0.5,
weights=weights,
progress=progress,
block=SwinTransformerBlockV2,
downsample_layer=PatchMergingV2,
**kwargs,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment