Unverified Commit faf6c6cd authored by Rui Xu's avatar Rui Xu Committed by GitHub
Browse files

[Enhancement] Revise the interface of upfirdn2d function (#1195)

* revise the interface of upfirdn2d function

* adopt to_2tuple
parent 5f9e6b61
...@@ -99,6 +99,7 @@ import torch ...@@ -99,6 +99,7 @@ import torch
from torch.autograd import Function from torch.autograd import Function
from torch.nn import functional as F from torch.nn import functional as F
from mmcv.utils import to_2tuple
from ..utils import ext_loader from ..utils import ext_loader
upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d']) upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
...@@ -249,20 +250,39 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): ...@@ -249,20 +250,39 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
Args: Args:
input (Tensor): Tensor with shape of (n, c, h, w). input (Tensor): Tensor with shape of (n, c, h, w).
kernel (Tensor): Filter kernel. kernel (Tensor): Filter kernel.
up (int, optional): Upsampling factor. Defaults to 1. up (int | tuple[int], optional): Upsampling factor. If given a number,
down (int, optional): Downsampling factor. Defaults to 1. we will use this factor for the both height and width side.
pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad). Defaults to 1.
Defaults to (0, 0). down (int | tuple[int], optional): Downsampling factor. If given a
number, we will use this factor for the both height and width side.
Defaults to 1.
pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad) or
(x_pad_0, x_pad_1, y_pad_0, y_pad_1). Defaults to (0, 0).
Returns: Returns:
Tensor: Tensor after UpFIRDn. Tensor: Tensor after UpFIRDn.
""" """
if input.device.type == 'cpu': if input.device.type == 'cpu':
out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], if len(pad) == 2:
pad[1], pad[0], pad[1]) pad = (pad[0], pad[1], pad[0], pad[1])
up = to_2tuple(up)
down = to_2tuple(down)
out = upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1],
pad[0], pad[1], pad[2], pad[3])
else: else:
out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), _up = to_2tuple(up)
(pad[0], pad[1], pad[0], pad[1]))
_down = to_2tuple(down)
if len(pad) == 4:
_pad = pad
elif len(pad) == 2:
_pad = (pad[0], pad[1], pad[0], pad[1])
out = UpFirDn2d.apply(input, kernel, _up, _down, _pad)
return out return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment