Unverified Commit f14682a8 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Generalize ConvNormActivation function to accept tuple for some parameters (#6251)

* Make ConvNormActivation function accept tuple for kernel_size, stride, padding, and dilation

* Fix the method to get the conv_dim

* Simplify if-elif logic
parent 54160313
import warnings
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union, Tuple, Sequence
import torch
from torch import Tensor
from ..utils import _log_api_usage_once
from ..utils import _log_api_usage_once, _make_ntuple
interpolate = torch.nn.functional.interpolate
......@@ -70,20 +70,26 @@ class ConvNormActivation(torch.nn.Sequential):
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
kernel_size: Union[int, Tuple[int, ...]] = 3,
stride: Union[int, Tuple[int, ...]] = 1,
padding: Optional[Union[int, Tuple[int, ...], str]] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: int = 1,
dilation: Union[int, Tuple[int, ...]] = 1,
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
) -> None:
if padding is None:
padding = (kernel_size - 1) // 2 * dilation
if isinstance(kernel_size, int) and isinstance(dilation, int):
padding = (kernel_size - 1) // 2 * dilation
else:
_conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
kernel_size = _make_ntuple(kernel_size, _conv_dim)
dilation = _make_ntuple(dilation, _conv_dim)
padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
if bias is None:
bias = norm_layer is None
......@@ -139,13 +145,13 @@ class Conv2dNormActivation(ConvNormActivation):
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
kernel_size: Union[int, Tuple[int, int]] = 3,
stride: Union[int, Tuple[int, int]] = 1,
padding: Optional[Union[int, Tuple[int, int], str]] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: int = 1,
dilation: Union[int, Tuple[int, int]] = 1,
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
) -> None:
......@@ -188,13 +194,13 @@ class Conv3dNormActivation(ConvNormActivation):
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Optional[Union[int, Tuple[int, int, int], str]] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: int = 1,
dilation: Union[int, Tuple[int, int, int]] = 1,
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
) -> None:
......
import collections
import math
import pathlib
import warnings
from itertools import repeat
from types import FunctionType
from typing import Any, BinaryIO, List, Optional, Tuple, Union
......@@ -569,3 +571,18 @@ def _log_api_usage_once(obj: Any) -> None:
if isinstance(obj, FunctionType):
name = obj.__name__
torch._C._log_api_usage_once(f"{module}.{name}")
def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
"""
Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
Otherwise we will make a tuple of length n, all with value of x.
reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
Args:
x (Any): input value
n (int): length of the resulting tuple
"""
if isinstance(x, collections.abc.Iterable):
return tuple(x)
return tuple(repeat(x, n))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment