Unverified commit e288f6ca authored by Hu Ye, committed by GitHub

Adding Swin Transformer architecture (#5491)



* add swin transformer

* Update swin_transformer.py

* Update swin_transformer.py

* fix lint

* fix lint

* refactor code

* add swin_transformer

* Update swin_transformer.py

* fix bug

* refactor code

* fix lint

* update init_weights

* move shift_window into attention

* refactor code

* fix bug

* Update swin_transformer.py

* Update swin_transformer.py

* fix lint

* add patch_merge

* fix bug

* Update swin_transformer.py

* Update swin_transformer.py

* Update swin_transformer.py

* refactor code

* Update swin_transformer.py

* refactor code

* fix lint

* refactor code

* add swin_tiny

* add swin_tiny.pkl

* fix lint

* Delete ModelTester.test_swin_tiny_expect.pkl

* add swin_tiny

* add

* add Optional to bias

* update init weights

* update init_weights and add no weight decay

* add no weight decay

* add set_weight_decay

* add set_weight_decay

* fix lint

* fix lint

* add lr_cos_min

* add other swin models

* Update torchvision/models/swin_transformer.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* refactor doc

* Update utils.py

* Update train.py

* Update train.py

* Update swin_transformer.py

* update model builder

* fix lint

* add

* Update torchvision/models/swin_transformer.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update torchvision/models/swin_transformer.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* update other model

* simplify the model name just like ViT

* add lr_cos_min

* fix lint

* fix lint

* Update swin_transformer.py

* Update swin_transformer.py

* Update swin_transformer.py

* Delete ModelTester.test_swin_tiny_expect.pkl

* add swin_t

* refactor code

* Update train.py

* add swin_s

* ignore a error of mypy

* Update swin_transformer.py

* fix lint

* add swin_b

* add swin_l

* refactor code

* Update train.py

* move relative_position_bias to __init__

* fix formatting

* Revert "fix formatting"

This reverts commit 41faba232668f7ac4273a0cf632c0d0130c7ce9c.

* Revert "move relative_position_bias to __init__"

This reverts commit f0615440bf18617dc0e5dc4839bd5ed27e5ed010.

* refactor code

* Remove deprecated meta-data from `_COMMON_META`

* fix linter

* add pretrained weights for swin_t

* fix format

* apply ufmt

* add documentation

* update references README

* adding new style docs

* update pre-trained weights values

* remove other variants

* fix typo

* Remove expect for the variants not yet supported
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Joao Gomes <jdsgomes@fb.com>
parent bb1ab475
......@@ -42,6 +42,7 @@ architectures for image classification:
- `RegNet`_
- `VisionTransformer`_
- `ConvNeXt`_
- `SwinTransformer`_
You can construct a model with random weights by calling its constructor:
......@@ -97,6 +98,7 @@ You can construct a model with random weights by calling its constructor:
convnext_small = models.convnext_small()
convnext_base = models.convnext_base()
convnext_large = models.convnext_large()
swin_t = models.swin_t()
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
......@@ -219,6 +221,7 @@ convnext_tiny 82.520 96.146
convnext_small 83.616 96.650
convnext_base 84.062 96.870
convnext_large 84.414 96.976
swin_t 81.358 95.526
================================ ============= =============
......@@ -238,6 +241,7 @@ convnext_large 84.414 96.976
.. _RegNet: https://arxiv.org/abs/2003.13678
.. _VisionTransformer: https://arxiv.org/abs/2010.11929
.. _ConvNeXt: https://arxiv.org/abs/2201.03545
.. _SwinTransformer: https://arxiv.org/abs/2103.14030
.. currentmodule:: torchvision.models
......@@ -450,6 +454,15 @@ ConvNeXt
convnext_base
convnext_large
SwinTransformer
---------------
.. autosummary::
:toctree: generated/
:template: function.rst
swin_t
Quantized Models
----------------
......
SwinTransformer
===============
.. currentmodule:: torchvision.models
The SwinTransformer model is based on the `Swin Transformer: Hierarchical Vision
Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`__
paper.
Model builders
--------------
The following model builders can be used to instantiate a SwinTransformer model.
`swin_t` can be instantiated with pre-trained weights and all others without.
All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
swin_t
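For example, ``swin_t`` can be built with or without the pre-trained weights added in this commit (a minimal usage sketch)::

    from torchvision.models import swin_t, Swin_T_Weights

    model = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1)
    model.eval()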
......@@ -46,6 +46,7 @@ weights:
models/resnet
models/resnext
models/squeezenet
models/swin_transformer
models/vgg
models/vision_transformer
......
......@@ -224,6 +224,18 @@ Note that the above command corresponds to training on a single node with 8 GPUs
For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
and `--batch-size 64`.
### SwinTransformer
```
torchrun --nproc_per_node=8 train.py \
--model swin_t --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 \
--bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear \
--lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 \
--clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ra
```
Note that `--val-resize-size` was optimized in a post-training step; see the model's `Weights` entry for the exact value.
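The flags `--norm-weight-decay 0.0`, `--bias-weight-decay 0.0` and `--transformer-embedding-decay 0.0` disable weight decay for normalization weights, biases and the relative-position-bias tables, while `--lr-min` sets the floor of the cosine schedule. The sketch below shows roughly how such parameter grouping can be expressed with plain PyTorch; it is illustrative only, since the reference script builds the groups with its own `utils.set_weight_decay` helper (see the `train.py` changes further down).
```
import torch
from torchvision.models import swin_t

model = swin_t()

# Rough grouping rule: 1-D parameters (norm weights, biases) and the
# relative-position-bias tables get no weight decay; everything else does.
decay, no_decay = [], []
for name, param in model.named_parameters():
    if param.ndim == 1 or name.endswith(".bias") or "relative_position_bias_table" in name:
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.05}, {"params": no_decay, "weight_decay": 0.0}],
    lr=0.001,
)
```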
## Mixed precision training
Automatic Mixed Precision (AMP) training on GPU in PyTorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp) module.
......
......@@ -233,7 +233,7 @@ def main(args):
if args.bias_weight_decay is not None:
custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
if args.transformer_embedding_decay is not None:
for key in ["class_token", "position_embedding", "relative_position_bias"]:
for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
parameters = utils.set_weight_decay(
model,
......@@ -267,7 +267,7 @@ def main(args):
main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
elif args.lr_scheduler == "cosineannealinglr":
main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs - args.lr_warmup_epochs
optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
)
elif args.lr_scheduler == "exponentiallr":
main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
......@@ -424,6 +424,7 @@ def get_args_parser(add_help=True):
parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
......@@ -12,6 +12,7 @@ from .shufflenetv2 import *
from .squeezenet import *
from .vgg import *
from .vision_transformer import *
from .swin_transformer import *
from . import detection
from . import optical_flow
from . import quantization
......
from functools import partial
from typing import Optional, Callable, List, Any
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param
from .convnext import Permute
from .vision_transformer import MLPBlock
__all__ = [
"SwinTransformer",
"Swin_T_Weights",
"swin_t",
]
class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
dim (int): Number of input channels.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
"""
def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x: Tensor):
B, H, W, C = x.shape
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
x = x.view(B, H // 2, W // 2, 2 * C)
return x
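A quick shape check for `PatchMerging` (an illustrative sketch, not part of the file): each 2x2 neighbourhood of patches is concatenated and linearly projected, so the spatial resolution halves while the channel dimension doubles.
```
import torch
from torchvision.models.swin_transformer import PatchMerging

merge = PatchMerging(dim=96)
out = merge(torch.rand(2, 56, 56, 96))  # input in (B, H, W, C) layout
print(out.shape)  # torch.Size([2, 28, 28, 192])
```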
def shifted_window_attention(
input: Tensor,
qkv_weight: Tensor,
proj_weight: Tensor,
relative_position_bias: Tensor,
window_size: int,
num_heads: int,
shift_size: int = 0,
attention_dropout: float = 0.0,
dropout: float = 0.0,
qkv_bias: Optional[Tensor] = None,
proj_bias: Optional[Tensor] = None,
):
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both shifted and non-shifted windows.
Args:
input (Tensor[N, H, W, C]): The 4-dimensional input tensor.
qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
relative_position_bias (Tensor): The learned relative position bias added to attention.
window_size (int): Window size.
num_heads (int): Number of attention heads.
shift_size (int): Shift size for shifted window attention. Default: 0.
attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
dropout (float): Dropout ratio of output. Default: 0.0.
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
Returns:
Tensor[N, H, W, C]: The output tensor after shifted window attention.
"""
B, H, W, C = input.shape
# pad feature maps to multiples of window size
pad_r = (window_size - W % window_size) % window_size
pad_b = (window_size - H % window_size) % window_size
x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
_, pad_H, pad_W, _ = x.shape
# If window size is larger than feature size, there is no need to shift window.
if window_size == min(pad_H, pad_W):
shift_size = 0
# cyclic shift
if shift_size > 0:
x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
# partition windows
num_windows = (pad_H // window_size) * (pad_W // window_size)
x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size * window_size, C) # B*nW, Ws*Ws, C
# multi-head attention
qkv = F.linear(x, qkv_weight, qkv_bias)
qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * (C // num_heads) ** -0.5
attn = q.matmul(k.transpose(-2, -1))
# add relative position bias
attn = attn + relative_position_bias
if shift_size > 0:
# generate attention mask
attn_mask = x.new_zeros((pad_H, pad_W))
slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None))
count = 0
for h in slices:
for w in slices:
attn_mask[h[0] : h[1], w[0] : w[1]] = count
count += 1
attn_mask = attn_mask.view(pad_H // window_size, window_size, pad_W // window_size, window_size)
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size * window_size)
attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, num_heads, x.size(1), x.size(1))
attn = F.softmax(attn, dim=-1)
attn = F.dropout(attn, p=attention_dropout)
x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
x = F.linear(x, proj_weight, proj_bias)
x = F.dropout(x, p=dropout)
# reverse windows
x = x.view(B, pad_H // window_size, pad_W // window_size, window_size, window_size, C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)
# reverse cyclic shift
if shift_size > 0:
x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
# unpad features
x = x[:, :H, :W, :].contiguous()
return x
torch.fx.wrap("shifted_window_attention")
class ShiftedWindowAttention(nn.Module):
"""
See :func:`shifted_window_attention`.
"""
def __init__(
self,
dim: int,
window_size: int,
shift_size: int,
num_heads: int,
qkv_bias: bool = True,
proj_bias: bool = True,
attention_dropout: float = 0.0,
dropout: float = 0.0,
):
super().__init__()
self.window_size = window_size
self.shift_size = shift_size
self.num_heads = num_heads
self.attention_dropout = attention_dropout
self.dropout = dropout
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size)
coords_w = torch.arange(self.window_size)
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size - 1
relative_coords[:, :, 0] *= 2 * self.window_size - 1
relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
def forward(self, x: Tensor):
relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view(
self.window_size * self.window_size, self.window_size * self.window_size, -1
)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
return shifted_window_attention(
x,
self.qkv.weight,
self.proj.weight,
relative_position_bias,
self.window_size,
self.num_heads,
shift_size=self.shift_size,
attention_dropout=self.attention_dropout,
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
)
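Usage sketch for the attention module (illustrative only): it consumes and returns `(B, H, W, C)` feature maps, with `shift_size` typically set to `window_size // 2` for the shifted variant.
```
import torch
from torchvision.models.swin_transformer import ShiftedWindowAttention

attn = ShiftedWindowAttention(dim=96, window_size=7, shift_size=3, num_heads=3)
out = attn(torch.rand(2, 56, 56, 96))
print(out.shape)  # torch.Size([2, 56, 56, 96])
```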
class SwinTransformerBlock(nn.Module):
"""
Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size. Default: 7.
shift_size (int): Shift size for shifted window attention. Default: 0.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
"""
def __init__(
self,
dim: int,
num_heads: int,
window_size: int = 7,
shift_size: int = 0,
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = ShiftedWindowAttention(
dim,
window_size,
shift_size,
num_heads,
attention_dropout=attention_dropout,
dropout=dropout,
)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout)
def forward(self, x: Tensor):
x = x + self.stochastic_depth(self.attn(self.norm1(x)))
x = x + self.stochastic_depth(self.mlp(self.norm2(x)))
return x
class SwinTransformer(nn.Module):
"""
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
Shifted Windows" <https://arxiv.org/pdf/2103.14030>`_ paper.
Args:
patch_size (int): Patch size.
embed_dim (int): Patch embedding dimension.
depths (List(int)): Depth of each Swin Transformer layer.
num_heads (List(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
num_classes (int): Number of classes for classification head. Default: 1000.
block (nn.Module, optional): SwinTransformer Block. Default: None.
norm_layer (nn.Module, optional): Normalization layer. Default: None.
"""
def __init__(
self,
patch_size: int,
embed_dim: int,
depths: List[int],
num_heads: List[int],
window_size: int = 7,
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
num_classes: int = 1000,
norm_layer: Optional[Callable[..., nn.Module]] = None,
block: Optional[Callable[..., nn.Module]] = None,
):
super().__init__()
_log_api_usage_once(self)
self.num_classes = num_classes
if block is None:
block = SwinTransformerBlock
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-5)
layers: List[nn.Module] = []
# split image into non-overlapping patches
layers.append(
nn.Sequential(
nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size),
Permute([0, 2, 3, 1]),
norm_layer(embed_dim),
)
)
total_stage_blocks = sum(depths)
stage_block_id = 0
# build SwinTransformer blocks
for i_stage in range(len(depths)):
stage: List[nn.Module] = []
dim = embed_dim * 2 ** i_stage
for i_layer in range(depths[i_stage]):
# adjust stochastic depth probability based on the depth of the stage block
sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
stage.append(
block(
dim,
num_heads[i_stage],
window_size=window_size,
shift_size=0 if i_layer % 2 == 0 else window_size // 2,
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_dropout=attention_dropout,
stochastic_depth_prob=sd_prob,
norm_layer=norm_layer,
)
)
stage_block_id += 1
layers.append(nn.Sequential(*stage))
# add patch merging layer
if i_stage < (len(depths) - 1):
layers.append(PatchMerging(dim, norm_layer))
self.features = nn.Sequential(*layers)
num_features = embed_dim * 2 ** (len(depths) - 1)
self.norm = norm_layer(num_features)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.head = nn.Linear(num_features, num_classes)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.features(x)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.head(x)
return x
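To make the hyper-parameters concrete, the Swin-T configuration used by the `swin_t` builder below can be instantiated directly (an illustrative sketch):
```
import torch
from torchvision.models import SwinTransformer

# Swin-T: four stages with dims 96/192/384/768, depths 2/2/6/2, heads 3/6/12/24.
model = SwinTransformer(
    patch_size=4,
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_size=7,
    stochastic_depth_prob=0.2,
)
logits = model(torch.rand(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```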
def _swin_transformer(
patch_size: int,
embed_dim: int,
depths: List[int],
num_heads: List[int],
window_size: int,
stochastic_depth_prob: float,
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> SwinTransformer:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = SwinTransformer(
patch_size=patch_size,
embed_dim=embed_dim,
depths=depths,
num_heads=num_heads,
window_size=window_size,
stochastic_depth_prob=stochastic_depth_prob,
**kwargs,
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_COMMON_META = {
"categories": _IMAGENET_CATEGORIES,
}
class Swin_T_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_t-81486767.pth",
transforms=partial(
ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META,
"num_params": 28288354,
"min_size": (224, 224),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swin_t",
"metrics": {
"acc@1": 81.358,
"acc@5": 95.526,
},
},
)
DEFAULT = IMAGENET1K_V1
def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_tiny architecture from
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.
Args:
weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.Swin_T_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.Swin_T_Weights
:members:
"""
weights = Swin_T_Weights.verify(weights)
return _swin_transformer(
patch_size=4,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
stochastic_depth_prob=0.2,
weights=weights,
progress=progress,
**kwargs,
)
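An end-to-end inference sketch with the pre-trained weights added in this commit (assuming the multi-weight API exposes the preprocessing preset via `weights.transforms()`, as for the other classification models):
```
import torch
from torchvision.models import swin_t, Swin_T_Weights

weights = Swin_T_Weights.IMAGENET1K_V1
model = swin_t(weights=weights).eval()
preprocess = weights.transforms()  # bicubic resize to 238, center crop to 224, ImageNet normalization

with torch.no_grad():
    logits = model(preprocess(torch.rand(1, 3, 256, 256)))
print(weights.meta["categories"][logits.argmax(dim=1).item()])
```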