Unverified commit f14682a8, authored by YosuaMichael and committed by GitHub

Generalize ConvNormActivation function to accept tuple for some parameters (#6251)

* Make ConvNormActivation function accept tuple for kernel_size, stride, padding, and dilation

* Fix the method to get the conv_dim

* Simplify if-elif logic
parent 54160313
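A quick usage sketch of what this generalization enables (illustrative only, not part of the commit; it assumes the public Conv2dNormActivation wrapper exported from torchvision.ops):

import torch
from torchvision.ops import Conv2dNormActivation

# kernel_size, stride, and dilation may now be per-dimension tuples; when padding is
# left as None, a "same"-style padding is derived independently for each dimension.
block = Conv2dNormActivation(in_channels=3, out_channels=16, kernel_size=(3, 5), stride=(1, 2))
out = block(torch.randn(1, 3, 32, 32))
print(out.shape)  # expected: torch.Size([1, 16, 32, 16]) -> height kept by padding (1, 2), width halved by stride 2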
 import warnings
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Union, Tuple, Sequence
 import torch
 from torch import Tensor
-from ..utils import _log_api_usage_once
+from ..utils import _log_api_usage_once, _make_ntuple
 interpolate = torch.nn.functional.interpolate
@@ -70,20 +70,26 @@ class ConvNormActivation(torch.nn.Sequential):
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: int = 3,
-        stride: int = 1,
-        padding: Optional[int] = None,
+        kernel_size: Union[int, Tuple[int, ...]] = 3,
+        stride: Union[int, Tuple[int, ...]] = 1,
+        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
         groups: int = 1,
         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
-        dilation: int = 1,
+        dilation: Union[int, Tuple[int, ...]] = 1,
         inplace: Optional[bool] = True,
         bias: Optional[bool] = None,
         conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
     ) -> None:
         if padding is None:
-            padding = (kernel_size - 1) // 2 * dilation
+            if isinstance(kernel_size, int) and isinstance(dilation, int):
+                padding = (kernel_size - 1) // 2 * dilation
+            else:
+                _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
+                kernel_size = _make_ntuple(kernel_size, _conv_dim)
+                dilation = _make_ntuple(dilation, _conv_dim)
+                padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
         if bias is None:
             bias = norm_layer is None
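For clarity, here is the new default-padding rule pulled out into a standalone sketch. compute_padding is a hypothetical name used only for this illustration; in the commit the logic lives inline in ConvNormActivation.__init__, and the _make_ntuple helper is the one added further down in this diff.

import collections.abc
from itertools import repeat

def _make_ntuple(x, n):
    # Same behaviour as the helper added by this commit: pass iterables through, repeat scalars n times.
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, n))

def compute_padding(kernel_size, dilation):
    # Scalar case keeps the old behaviour; otherwise padding is computed per dimension.
    if isinstance(kernel_size, int) and isinstance(dilation, int):
        return (kernel_size - 1) // 2 * dilation
    conv_dim = len(kernel_size) if isinstance(kernel_size, collections.abc.Sequence) else len(dilation)
    kernel_size = _make_ntuple(kernel_size, conv_dim)
    dilation = _make_ntuple(dilation, conv_dim)
    return tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(conv_dim))

print(compute_padding(3, 1))            # 1       (unchanged scalar path)
print(compute_padding((3, 5), 1))       # (1, 2)  (dilation broadcast to (1, 1))
print(compute_padding((3, 5), (1, 2)))  # (1, 4)  (per-dimension dilation respected)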
@@ -139,13 +145,13 @@ class Conv2dNormActivation(ConvNormActivation):
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: int = 3,
-        stride: int = 1,
-        padding: Optional[int] = None,
+        kernel_size: Union[int, Tuple[int, int]] = 3,
+        stride: Union[int, Tuple[int, int]] = 1,
+        padding: Optional[Union[int, Tuple[int, int], str]] = None,
         groups: int = 1,
         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
-        dilation: int = 1,
+        dilation: Union[int, Tuple[int, int]] = 1,
         inplace: Optional[bool] = True,
         bias: Optional[bool] = None,
     ) -> None:
@@ -188,13 +194,13 @@ class Conv3dNormActivation(ConvNormActivation):
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: int = 3,
-        stride: int = 1,
-        padding: Optional[int] = None,
+        kernel_size: Union[int, Tuple[int, int, int]] = 3,
+        stride: Union[int, Tuple[int, int, int]] = 1,
+        padding: Optional[Union[int, Tuple[int, int, int], str]] = None,
         groups: int = 1,
         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
-        dilation: int = 1,
+        dilation: Union[int, Tuple[int, int, int]] = 1,
         inplace: Optional[bool] = True,
         bias: Optional[bool] = None,
     ) -> None:
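The 3D variant follows the same per-dimension rule; a hedged sketch assuming torchvision.ops.Conv3dNormActivation:

import torch
from torchvision.ops import Conv3dNormActivation

# With kernel_size=(1, 3, 3) and dilation=(1, 2, 2), the derived padding is (0, 2, 2),
# so depth, height, and width of the input are all preserved.
block3d = Conv3dNormActivation(in_channels=1, out_channels=8, kernel_size=(1, 3, 3), dilation=(1, 2, 2))
out = block3d(torch.randn(1, 1, 4, 32, 32))
print(out.shape)  # expected: torch.Size([1, 8, 4, 32, 32])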
...
+import collections
 import math
 import pathlib
 import warnings
+from itertools import repeat
 from types import FunctionType
 from typing import Any, BinaryIO, List, Optional, Tuple, Union
@@ -569,3 +571,18 @@ def _log_api_usage_once(obj: Any) -> None:
     if isinstance(obj, FunctionType):
         name = obj.__name__
     torch._C._log_api_usage_once(f"{module}.{name}")
+
+
+def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
+    """
+    Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
+    Otherwise we will make a tuple of length n, all with value of x.
+    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
+
+    Args:
+        x (Any): input value
+        n (int): length of the resulting tuple
+    """
+    if isinstance(x, collections.abc.Iterable):
+        return tuple(x)
+    return tuple(repeat(x, n))
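A couple of concrete input/output pairs for the new helper (illustrative; the import path assumes the second file in this diff is torchvision/utils.py, which the `from ..utils import _log_api_usage_once, _make_ntuple` change above suggests):

from torchvision.utils import _make_ntuple  # private helper; module path assumed as noted above

print(_make_ntuple(3, 2))           # (3, 3)     -> a scalar is repeated n times
print(_make_ntuple((1, 2), 2))      # (1, 2)     -> an iterable is simply converted to a tuple
print(_make_ntuple([1, 2, 3], 3))   # (1, 2, 3)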