Commit d88d8961 authored by eellison's avatar eellison Committed by Francisco Massa
Browse files

Make maskrcnn scriptable (#1407)

* almost working...

* respond to comments

* add empty tensor op, handle different output types in generalized rcnn

* clean ups

* address comments

* more changes

* it's working!

* torchscript bugs

* add script/ eager test

* eval script model

* fix flake

* division import

* py2 compat

* update test, fix arange bug

* import division statement

* fix linter

* fixes

* changes needed for JIT master

* cleanups

* remove imagelist_to

* requested changes

* Make FPN backwards-compatible and torchscript compatible

We remove support for feature channels=0, but support for it was already a bit limited

* Fix ONNX regression
parent b590f8c6
...@@ -2,7 +2,9 @@ from collections import OrderedDict ...@@ -2,7 +2,9 @@ from collections import OrderedDict
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn, Tensor
from torch.jit.annotations import Tuple, List, Dict
class FeaturePyramidNetwork(nn.Module): class FeaturePyramidNetwork(nn.Module):
...@@ -42,14 +44,13 @@ class FeaturePyramidNetwork(nn.Module): ...@@ -42,14 +44,13 @@ class FeaturePyramidNetwork(nn.Module):
>>> ('feat3', torch.Size([1, 5, 8, 8]))] >>> ('feat3', torch.Size([1, 5, 8, 8]))]
""" """
def __init__(self, in_channels_list, out_channels, extra_blocks=None): def __init__(self, in_channels_list, out_channels, extra_blocks=None):
super(FeaturePyramidNetwork, self).__init__() super(FeaturePyramidNetwork, self).__init__()
self.inner_blocks = nn.ModuleList() self.inner_blocks = nn.ModuleList()
self.layer_blocks = nn.ModuleList() self.layer_blocks = nn.ModuleList()
for in_channels in in_channels_list: for in_channels in in_channels_list:
if in_channels == 0: if in_channels == 0:
continue raise ValueError("in_channels=0 is currently not supported")
inner_block_module = nn.Conv2d(in_channels, out_channels, 1) inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1) layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
self.inner_blocks.append(inner_block_module) self.inner_blocks.append(inner_block_module)
...@@ -65,7 +66,46 @@ class FeaturePyramidNetwork(nn.Module): ...@@ -65,7 +66,46 @@ class FeaturePyramidNetwork(nn.Module):
assert isinstance(extra_blocks, ExtraFPNBlock) assert isinstance(extra_blocks, ExtraFPNBlock)
self.extra_blocks = extra_blocks self.extra_blocks = extra_blocks
def get_result_from_inner_blocks(self, x, idx):
# type: (Tensor, int)
"""
This is equivalent to self.inner_blocks[idx](x),
but torchscript doesn't support this yet
"""
num_blocks = 0
for m in self.inner_blocks:
num_blocks += 1
if idx < 0:
idx += num_blocks
i = 0
out = x
for module in self.inner_blocks:
if i == idx:
out = module(x)
i += 1
return out
def get_result_from_layer_blocks(self, x, idx):
# type: (Tensor, int)
"""
This is equivalent to self.layer_blocks[idx](x),
but torchscript doesn't support this yet
"""
num_blocks = 0
for m in self.layer_blocks:
num_blocks += 1
if idx < 0:
idx += num_blocks
i = 0
out = x
for module in self.layer_blocks:
if i == idx:
out = module(x)
i += 1
return out
def forward(self, x): def forward(self, x):
# type: (Dict[str, Tensor])
""" """
Computes the FPN for a set of feature maps. Computes the FPN for a set of feature maps.
...@@ -80,19 +120,16 @@ class FeaturePyramidNetwork(nn.Module): ...@@ -80,19 +120,16 @@ class FeaturePyramidNetwork(nn.Module):
names = list(x.keys()) names = list(x.keys())
x = list(x.values()) x = list(x.values())
last_inner = self.inner_blocks[-1](x[-1]) last_inner = self.get_result_from_inner_blocks(x[-1], -1)
results = [] results = []
results.append(self.layer_blocks[-1](last_inner)) results.append(self.get_result_from_layer_blocks(last_inner, -1))
for feature, inner_block, layer_block in zip(
x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1] for idx in range(len(x) - 2, -1, -1):
): inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
if not inner_block:
continue
inner_lateral = inner_block(feature)
feat_shape = inner_lateral.shape[-2:] feat_shape = inner_lateral.shape[-2:]
inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest") inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
last_inner = inner_lateral + inner_top_down last_inner = inner_lateral + inner_top_down
results.insert(0, layer_block(last_inner)) results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))
if self.extra_blocks is not None: if self.extra_blocks is not None:
results, names = self.extra_blocks(results, x, names) results, names = self.extra_blocks(results, x, names)
...@@ -127,6 +164,7 @@ class LastLevelMaxPool(ExtraFPNBlock): ...@@ -127,6 +164,7 @@ class LastLevelMaxPool(ExtraFPNBlock):
Applies a max_pool2d on top of the last feature map Applies a max_pool2d on top of the last feature map
""" """
def forward(self, x, y, names): def forward(self, x, y, names):
# type: (List[Tensor], List[Tensor], List[str])
names.append("pool") names.append("pool")
x.append(F.max_pool2d(x[-1], 1, 2, 0)) x.append(F.max_pool2d(x[-1], 1, 2, 0))
return x, names return x, names
......
from __future__ import division from __future__ import division
from collections import OrderedDict
from torch.jit.annotations import Optional, List
from torch import Tensor
""" """
helper class that supports empty tensors on some nn functions. helper class that supports empty tensors on some nn functions.
...@@ -12,40 +15,9 @@ is implemented ...@@ -12,40 +15,9 @@ is implemented
import math import math
import torch import torch
from torch.nn.modules.utils import _ntuple from torchvision.ops import _new_empty_tensor
from torch.nn import Module, Conv2d
import torch.nn.functional as F
class _NewEmptyTensorOp(torch.autograd.Function):
@staticmethod
def forward(ctx, x, new_shape):
ctx.shape = x.shape
return x.new_empty(new_shape)
@staticmethod
def backward(ctx, grad):
shape = ctx.shape
return _NewEmptyTensorOp.apply(grad, shape), None
class Conv2d(torch.nn.Conv2d):
"""
Equivalent to nn.Conv2d, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
def forward(self, x):
if x.numel() > 0:
return super(Conv2d, self).forward(x)
# get output shape
output_shape = [
(i + 2 * p - (di * (k - 1) + 1)) // d + 1
for i, p, di, k, d in zip(
x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
)
]
output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
return _NewEmptyTensorOp.apply(x, output_shape)
class ConvTranspose2d(torch.nn.ConvTranspose2d): class ConvTranspose2d(torch.nn.ConvTranspose2d):
...@@ -56,22 +28,33 @@ class ConvTranspose2d(torch.nn.ConvTranspose2d): ...@@ -56,22 +28,33 @@ class ConvTranspose2d(torch.nn.ConvTranspose2d):
""" """
def forward(self, x): def forward(self, x):
if x.numel() > 0: if x.numel() > 0:
return super(ConvTranspose2d, self).forward(x) return self.super_forward(x)
# get output shape # get output shape
output_shape = [ output_shape = [
(i - 1) * d - 2 * p + (di * (k - 1) + 1) + op (i - 1) * d - 2 * p + (di * (k - 1) + 1) + op
for i, p, di, k, d, op in zip( for i, p, di, k, d, op in zip(
x.shape[-2:], x.shape[-2:],
self.padding, list(self.padding),
self.dilation, list(self.dilation),
self.kernel_size, list(self.kernel_size),
self.stride, list(self.stride),
self.output_padding, list(self.output_padding),
) )
] ]
output_shape = [x.shape[0], self.bias.shape[0]] + output_shape output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
return _NewEmptyTensorOp.apply(x, output_shape) return _new_empty_tensor(x, output_shape)
def super_forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor
if self.padding_mode != 'zeros':
raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
return F.conv_transpose2d(
input, self.weight, self.bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
class BatchNorm2d(torch.nn.BatchNorm2d): class BatchNorm2d(torch.nn.BatchNorm2d):
...@@ -85,12 +68,39 @@ class BatchNorm2d(torch.nn.BatchNorm2d): ...@@ -85,12 +68,39 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
return super(BatchNorm2d, self).forward(x) return super(BatchNorm2d, self).forward(x)
# get output shape # get output shape
output_shape = x.shape output_shape = x.shape
return _NewEmptyTensorOp.apply(x, output_shape) return _new_empty_tensor(x, output_shape)
def _check_size_scale_factor(dim, size, scale_factor):
# type: (int, Optional[List[int]], Optional[float]) -> None
if size is None and scale_factor is None:
raise ValueError("either size or scale_factor should be defined")
if size is not None and scale_factor is not None:
raise ValueError("only one of size or scale_factor should be defined")
if not (scale_factor is not None and len(scale_factor) != dim):
raise ValueError(
"scale_factor shape must match input shape. "
"Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
)
def interpolate( def _output_size(dim, input, size, scale_factor):
input, size=None, scale_factor=None, mode="nearest", align_corners=None # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
): assert dim == 2
_check_size_scale_factor(dim, size, scale_factor)
if size is not None:
return size
# if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
assert scale_factor is not None and isinstance(scale_factor, (int, float))
scale_factors = [scale_factor, scale_factor]
# math.floor might return float in py2.7
return [
int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
]
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
""" """
Equivalent to nn.functional.interpolate, but with support for empty batch sizes. Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this This will eventually be supported natively by PyTorch, and this
...@@ -101,34 +111,9 @@ def interpolate( ...@@ -101,34 +111,9 @@ def interpolate(
input, size, scale_factor, mode, align_corners input, size, scale_factor, mode, align_corners
) )
def _check_size_scale_factor(dim): output_shape = _output_size(2, input, size, scale_factor)
if size is None and scale_factor is None:
raise ValueError("either size or scale_factor should be defined")
if size is not None and scale_factor is not None:
raise ValueError("only one of size or scale_factor should be defined")
if (
scale_factor is not None and
isinstance(scale_factor, tuple) and
len(scale_factor) != dim
):
raise ValueError(
"scale_factor shape must match input shape. "
"Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
)
def _output_size(dim):
_check_size_scale_factor(dim)
if size is not None:
return size
scale_factors = _ntuple(dim)(scale_factor)
# math.floor might return float in py2.7
return [
int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
]
output_shape = tuple(_output_size(2))
output_shape = input.shape[:-2] + output_shape output_shape = input.shape[:-2] + output_shape
return _NewEmptyTensorOp.apply(input, output_shape) return _new_empty_tensor(input, output_shape)
# This is not in nn # This is not in nn
......
import torch
from torch.jit.annotations import List
from torch import Tensor
def _new_empty_tensor(x, shape):
# type: (Tensor, List[int]) -> Tensor
"""
Arguments:
input (Tensor): input tensor
shape List[int]: the new empty tensor shape
Returns:
output (Tensor)
"""
return torch.ops.torchvision._new_empty_tensor_op(x, shape)
from __future__ import division
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn, Tensor
import torchvision
from torchvision.ops import roi_align from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area from torchvision.ops.boxes import box_area
from torch.jit.annotations import Optional, List, Dict, Tuple
import torchvision
# copying result_idx_in_level to a specific index in result[] # copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet. # is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX # _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices # that merges the levels to the right indices
@torch.jit.unused
def _onnx_merge_levels(levels, unmerged_results): def _onnx_merge_levels(levels, unmerged_results):
# type: (Tensor, List[Tensor]) -> Tensor
first_result = unmerged_results[0] first_result = unmerged_results[0]
dtype, device = first_result.dtype, first_result.device dtype, device = first_result.dtype, first_result.device
res = torch.zeros((levels.size(0), first_result.size(1), res = torch.zeros((levels.size(0), first_result.size(1),
...@@ -28,6 +33,13 @@ def _onnx_merge_levels(levels, unmerged_results): ...@@ -28,6 +33,13 @@ def _onnx_merge_levels(levels, unmerged_results):
return res return res
# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744/pytorch/issues/26744
def initLevelMapper(k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
# type: (int, int, int, int, float)
return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)
@torch.jit.script
class LevelMapper(object): class LevelMapper(object):
"""Determine which FPN level each RoI in a set of RoIs should map to based """Determine which FPN level each RoI in a set of RoIs should map to based
on the heuristic in the FPN paper. on the heuristic in the FPN paper.
...@@ -41,6 +53,7 @@ class LevelMapper(object): ...@@ -41,6 +53,7 @@ class LevelMapper(object):
""" """
def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6): def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
# type: (int, int, int, int, float)
self.k_min = k_min self.k_min = k_min
self.k_max = k_max self.k_max = k_max
self.s0 = canonical_scale self.s0 = canonical_scale
...@@ -48,6 +61,7 @@ class LevelMapper(object): ...@@ -48,6 +61,7 @@ class LevelMapper(object):
self.eps = eps self.eps = eps
def __call__(self, boxlists): def __call__(self, boxlists):
# type: (List[Tensor])
""" """
Arguments: Arguments:
boxlists (list[BoxList]) boxlists (list[BoxList])
...@@ -90,6 +104,11 @@ class MultiScaleRoIAlign(nn.Module): ...@@ -90,6 +104,11 @@ class MultiScaleRoIAlign(nn.Module):
""" """
__annotations__ = {
'scales': Optional[List[float]],
'map_levels': Optional[LevelMapper]
}
def __init__(self, featmap_names, output_size, sampling_ratio): def __init__(self, featmap_names, output_size, sampling_ratio):
super(MultiScaleRoIAlign, self).__init__() super(MultiScaleRoIAlign, self).__init__()
if isinstance(output_size, int): if isinstance(output_size, int):
...@@ -101,11 +120,12 @@ class MultiScaleRoIAlign(nn.Module): ...@@ -101,11 +120,12 @@ class MultiScaleRoIAlign(nn.Module):
self.map_levels = None self.map_levels = None
def convert_to_roi_format(self, boxes): def convert_to_roi_format(self, boxes):
# type: (List[Tensor])
concat_boxes = torch.cat(boxes, dim=0) concat_boxes = torch.cat(boxes, dim=0)
device, dtype = concat_boxes.device, concat_boxes.dtype device, dtype = concat_boxes.device, concat_boxes.dtype
ids = torch.cat( ids = torch.cat(
[ [
torch.full_like(b[:, :1], i, dtype=dtype, device=device) torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
for i, b in enumerate(boxes) for i, b in enumerate(boxes)
], ],
dim=0, dim=0,
...@@ -114,27 +134,37 @@ class MultiScaleRoIAlign(nn.Module): ...@@ -114,27 +134,37 @@ class MultiScaleRoIAlign(nn.Module):
return rois return rois
def infer_scale(self, feature, original_size): def infer_scale(self, feature, original_size):
# type: (Tensor, List[int])
# assumption: the scale is of the form 2 ** (-k), with k integer # assumption: the scale is of the form 2 ** (-k), with k integer
size = feature.shape[-2:] size = feature.shape[-2:]
possible_scales = [] possible_scales = torch.jit.annotate(List[float], [])
for s1, s2 in zip(size, original_size): for s1, s2 in zip(size, original_size):
approx_scale = float(s1) / s2 approx_scale = float(s1) / s2
scale = 2 ** torch.tensor(approx_scale).log2().round().item() scale = 2 ** float(torch.tensor(approx_scale).log2().round())
possible_scales.append(scale) possible_scales.append(scale)
assert possible_scales[0] == possible_scales[1] assert possible_scales[0] == possible_scales[1]
return possible_scales[0] return possible_scales[0]
def setup_scales(self, features, image_shapes): def setup_scales(self, features, image_shapes):
original_input_shape = tuple(max(s) for s in zip(*image_shapes)) # type: (List[Tensor], List[Tuple[int, int]])
assert len(image_shapes) != 0
max_x = 0
max_y = 0
for shape in image_shapes:
max_x = max(shape[0], max_x)
max_y = max(shape[1], max_y)
original_input_shape = (max_x, max_y)
scales = [self.infer_scale(feat, original_input_shape) for feat in features] scales = [self.infer_scale(feat, original_input_shape) for feat in features]
# get the levels in the feature map by leveraging the fact that the network always # get the levels in the feature map by leveraging the fact that the network always
# downsamples by a factor of 2 at each level. # downsamples by a factor of 2 at each level.
lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item() lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item() lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
self.scales = scales self.scales = scales
self.map_levels = LevelMapper(lvl_min, lvl_max) self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max))
def forward(self, x, boxes, image_shapes): def forward(self, x, boxes, image_shapes):
# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]])
""" """
Arguments: Arguments:
x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
...@@ -148,34 +178,43 @@ class MultiScaleRoIAlign(nn.Module): ...@@ -148,34 +178,43 @@ class MultiScaleRoIAlign(nn.Module):
Returns: Returns:
result (Tensor) result (Tensor)
""" """
x = [v for k, v in x.items() if k in self.featmap_names] x_filtered = []
num_levels = len(x) for k, v in x.items():
if k in self.featmap_names:
x_filtered.append(v)
num_levels = len(x_filtered)
rois = self.convert_to_roi_format(boxes) rois = self.convert_to_roi_format(boxes)
if self.scales is None: if self.scales is None:
self.setup_scales(x, image_shapes) self.setup_scales(x_filtered, image_shapes)
scales = self.scales
assert scales is not None
if num_levels == 1: if num_levels == 1:
return roi_align( return roi_align(
x[0], rois, x_filtered[0], rois,
output_size=self.output_size, output_size=self.output_size,
spatial_scale=self.scales[0], spatial_scale=scales[0],
sampling_ratio=self.sampling_ratio sampling_ratio=self.sampling_ratio
) )
levels = self.map_levels(boxes) mapper = self.map_levels
assert mapper is not None
levels = mapper(boxes)
num_rois = len(rois) num_rois = len(rois)
num_channels = x[0].shape[1] num_channels = x_filtered[0].shape[1]
dtype, device = x[0].dtype, x[0].device dtype, device = x_filtered[0].dtype, x_filtered[0].device
result = torch.zeros( result = torch.zeros(
(num_rois, num_channels,) + self.output_size, (num_rois, num_channels,) + self.output_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
results = [] tracing_results = []
for level, (per_level_feature, scale) in enumerate(zip(x, self.scales)): for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
idx_in_level = torch.nonzero(levels == level).squeeze(1) idx_in_level = torch.nonzero(levels == level).squeeze(1)
rois_per_level = rois[idx_in_level] rois_per_level = rois[idx_in_level]
...@@ -185,10 +224,11 @@ class MultiScaleRoIAlign(nn.Module): ...@@ -185,10 +224,11 @@ class MultiScaleRoIAlign(nn.Module):
spatial_scale=scale, sampling_ratio=self.sampling_ratio) spatial_scale=scale, sampling_ratio=self.sampling_ratio)
if torchvision._is_tracing(): if torchvision._is_tracing():
results.append(result_idx_in_level.to(dtype)) tracing_results.append(result_idx_in_level.to(dtype))
else: else:
result[idx_in_level] = result_idx_in_level result[idx_in_level] = result_idx_in_level
if torchvision._is_tracing(): if torchvision._is_tracing():
result = _onnx_merge_levels(levels, results) result = _onnx_merge_levels(levels, tracing_results)
return result return result
...@@ -2,13 +2,13 @@ import torch ...@@ -2,13 +2,13 @@ import torch
from torch import nn, Tensor from torch import nn, Tensor
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torch.jit.annotations import List from torch.jit.annotations import List, BroadcastingList2
from ._utils import convert_boxes_to_roi_format from ._utils import convert_boxes_to_roi_format
def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1): def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
# type: (Tensor, Tensor, int, float, int) -> Tensor # type: (Tensor, Tensor, BroadcastingList2[int], float, int) -> Tensor
""" """
Performs Region of Interest (RoI) Align operator described in Mask R-CNN Performs Region of Interest (RoI) Align operator described in Mask R-CNN
......
...@@ -2,13 +2,13 @@ import torch ...@@ -2,13 +2,13 @@ import torch
from torch import nn, Tensor from torch import nn, Tensor
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torch.jit.annotations import List from torch.jit.annotations import List, BroadcastingList2
from ._utils import convert_boxes_to_roi_format from ._utils import convert_boxes_to_roi_format
def roi_pool(input, boxes, output_size, spatial_scale=1.0): def roi_pool(input, boxes, output_size, spatial_scale=1.0):
# type: (Tensor, Tensor, int, float) -> Tensor # type: (Tensor, Tensor, BroadcastingList2[int], float) -> Tensor
""" """
Performs Region of Interest (RoI) Pool operator described in Fast R-CNN Performs Region of Interest (RoI) Pool operator described in Fast R-CNN
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment