Unverified Commit 869dbf1b authored by Yifei Yang, committed by GitHub

[Feature] Add Ops of StyleGAN3 (#2290)



* add bias_act

* support bias_act

* support filtered_lrelu

* support filtered_lrelu and upfirdn2d

* support conv2d_gradfix and fix filtered_lrelu

* fix lint

* fix lint

* fix c++ lint

* fix part comments

* fix lint

* rm redundant header

* fix upgrade pip

* fix as comment

* fix c++ lint

* fix ci

* fix-ut

* fix as comments

* add grad check

* remove redundant template

* Update mmcv/ops/bias_act.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* add typehint

* fix as comment

* complete type hints

* fix lint

* add test for conv_gradfix

* add test for conv_gradfix

* fix lint

* modify licenses and ops.md

* add zh op md

* add torch version policy for conv2d_gradfix

* fix lint

* fix as comments

* rename impl

* rm redundant function and add ut

* fix as comment

* fix lint

* fix lint

* fix as comments

* fix lint

* fix ut

* fix as comment

* fix as comment

* fix as comment

---------
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 0fb07d0e
......@@ -2,7 +2,10 @@
In this file, we list the operations that are under licenses other than Apache 2.0. Users should be careful about adopting these operations in any commercial product.
| Operation | Files | License |
| :--------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------: |
| upfirdn2d | [mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu](https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu) | NVIDIA License |
| fused_leaky_relu | [mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu](https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu) | NVIDIA License |
| Operation | Files | License |
| :--------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------: |
| upfirdn2d | [mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu) | NVIDIA License |
| fused_leaky_relu | [mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu) | NVIDIA License |
| bias_act | [mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu) | NVIDIA License |
| filtered_lrelu | [mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu) | NVIDIA License |
| conv2d_gradfix | [mmcv/ops/conv2d_gradfix.py](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/conv2d_gradfix.py) | NVIDIA License |
......@@ -61,3 +61,6 @@ We implement common ops used in detection, segmentation, etc.
| Voxelization | √ | √ | | | √ |
| PrRoIPool | | √ | | | |
| BezierAlign | √ | √ | | | |
| BiasAct | | √ | | | |
| FilteredLrelu | | √ | | | |
| Conv2dGradfix | | √ | | | |
......@@ -61,3 +61,6 @@ MMCV 提供了检测、分割等任务中常用的算子
| Voxelization | √ | √ | | | √ |
| PrRoIPool | | √ | | | |
| BezierAlign | √ | √ | | | |
| BiasAct | | √ | | | |
| FilteredLrelu | | √ | | | |
| Conv2dGradfix | | √ | | | |
......@@ -4,6 +4,7 @@ from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
from .bbox import bbox_overlaps
from .bezier_align import BezierAlign, bezier_align
from .bias_act import bias_act
from .border_align import BorderAlign, border_align
from .box_iou_quadri import box_iou_quadri
from .box_iou_rotated import box_iou_rotated
......@@ -11,6 +12,7 @@ from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .chamfer_distance import chamfer_distance
from .contour_expand import contour_expand
from .conv2d_gradfix import conv2d, conv_transpose2d
from .convex_iou import convex_giou, convex_iou
from .corner_pool import CornerPool
from .correlation import Correlation
......@@ -22,6 +24,7 @@ from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
from .deprecated_wrappers import Linear_deprecated as Linear
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d
from .filtered_lrelu import filtered_lrelu
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss)
from .furthest_point_sample import (furthest_point_sample,
......@@ -68,7 +71,7 @@ from .sync_bn import SyncBatchNorm
from .three_interpolate import three_interpolate
from .three_nn import three_nn
from .tin_shift import TINShift, tin_shift
from .upfirdn2d import upfirdn2d
from .upfirdn2d import filter2d, upfirdn2d, upsample2d
from .voxelize import Voxelization, voxelization
__all__ = [
......@@ -103,5 +106,6 @@ __all__ = [
'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
'PrRoIPool', 'prroi_pool', 'BezierAlign', 'bezier_align'
'PrRoIPool', 'prroi_pool', 'bias_act', 'filtered_lrelu', 'conv2d',
'conv_transpose2d', 'filter2d', 'upsample2d', 'BezierAlign', 'bezier_align'
]
# Modified from
# https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/bias_act.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# source: https://github.com/open-mmlab/mmediting/blob/dev-1.x/mmedit/models/editors/stylegan3/stylegan3_ops/ops/bias_act.py # noqa
"""Custom PyTorch ops for efficient bias and activation."""
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['bias_act'])
class EasyDict(dict):
"""Convenience class that behaves like a dict but allows access with the
attribute syntax."""
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name: str, value: Any) -> None:
self[name] = value
def __delattr__(self, name: str) -> None:
del self[name]
activation_funcs = {
'linear':
EasyDict(
func=lambda x, **_: x,
def_alpha=0,
def_gain=1,
cuda_idx=1,
ref='',
has_2nd_grad=False),
'relu':
EasyDict(
func=lambda x, **_: torch.nn.functional.relu(x),
def_alpha=0,
def_gain=np.sqrt(2),
cuda_idx=2,
ref='y',
has_2nd_grad=False),
'lrelu':
EasyDict(
func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
def_alpha=0.2,
def_gain=np.sqrt(2),
cuda_idx=3,
ref='y',
has_2nd_grad=False),
'tanh':
EasyDict(
func=lambda x, **_: torch.tanh(x),
def_alpha=0,
def_gain=1,
cuda_idx=4,
ref='y',
has_2nd_grad=True),
'sigmoid':
EasyDict(
func=lambda x, **_: torch.sigmoid(x),
def_alpha=0,
def_gain=1,
cuda_idx=5,
ref='y',
has_2nd_grad=True),
'elu':
EasyDict(
func=lambda x, **_: torch.nn.functional.elu(x),
def_alpha=0,
def_gain=1,
cuda_idx=6,
ref='y',
has_2nd_grad=True),
'selu':
EasyDict(
func=lambda x, **_: torch.nn.functional.selu(x),
def_alpha=0,
def_gain=1,
cuda_idx=7,
ref='y',
has_2nd_grad=True),
'softplus':
EasyDict(
func=lambda x, **_: torch.nn.functional.softplus(x),
def_alpha=0,
def_gain=1,
cuda_idx=8,
ref='y',
has_2nd_grad=True),
'swish':
EasyDict(
func=lambda x, **_: torch.sigmoid(x) * x,
def_alpha=0,
def_gain=np.sqrt(2),
cuda_idx=9,
ref='x',
has_2nd_grad=True),
}
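# Each entry in `activation_funcs` describes one activation: `func` is the
# reference PyTorch implementation, `def_alpha`/`def_gain` are the default
# shape parameter and output gain, `cuda_idx` selects the branch inside the
# CUDA kernel, `ref` names which tensor ('x', 'y' or '') the backward pass
# needs to save, and `has_2nd_grad` marks activations whose second-order
# gradient requires an extra kernel call.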
_null_tensor = torch.empty([0])
def bias_act(input: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dim: int = 1,
act: str = 'linear',
alpha: Optional[Union[float, int]] = None,
gain: Optional[float] = None,
clamp: Optional[float] = None,
use_custom_op: bool = True):
r"""Fused bias and activation function.
    Adds `bias` to the activation tensor `input`, evaluates the activation
    function `act`, and scales the result by `gain`. Each of these steps is
    optional.
In most cases, the fused op is considerably more efficient than performing
the same calculation using standard PyTorch ops. It supports first and
second order gradients, but not third order gradients.
Args:
input (torch.Tensor): Input activation tensor. Can be of any shape.
bias (torch.Tensor): Bias vector, or `None` to disable.
Must be a 1D tensor of the same type as `input`. The shape must
be known, and it must match the dimension of `input` corresponding
to `dim`. Defaults to None.
dim (int): The dimension in `input` corresponding to the elements of
            `bias`. The value of `dim` is ignored if `bias` is not specified.
Defaults to 1.
act (str): Name of the activation function to evaluate, or `"linear"`
to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid",
"swish", etc. See `activation_funcs` for a full list. `None` is not
allowed. Defaults to `linear`.
alpha (float or int): Shape parameter for the activation
function, or `None` to use the default. Defaults to None.
gain (float): Scaling factor for the output tensor, or `None`
to use default. See `activation_funcs` for the default scaling of
each activation function. If unsure, consider specifying 1.
Defaults to None.
clamp (float): Clamp the output values to `[-clamp, +clamp]`,
or `None` to disable the clamping (default). Defaults to None.
use_custom_op (bool): Whether to use customized op.
Defaults to True.
Returns:
torch.Tensor: Tensor of the same shape and datatype as `input`.
"""
assert isinstance(input, torch.Tensor)
if use_custom_op and input.is_cuda:
return _bias_act_cuda(
dim=dim, act=act, alpha=alpha, gain=gain,
clamp=clamp).apply(input, bias)
return _bias_act_ref(
input=input,
bias=bias,
dim=dim,
act=act,
alpha=alpha,
gain=gain,
clamp=clamp)
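# A minimal usage sketch (illustrative only; the shapes, the CUDA device and
# the `lrelu` choice below are assumptions for the example, not requirements
# of the op):
#
#   x = torch.randn(2, 8, 16, 16, device='cuda')  # NCHW feature map
#   b = torch.zeros(8, device='cuda')             # one bias value per channel
#   y = bias_act(x, b, dim=1, act='lrelu')        # fused bias + leaky ReLU
#
# On CPU tensors, or when `use_custom_op=False`, the call dispatches to the
# reference implementation `_bias_act_ref` below instead of the CUDA op.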
def _bias_act_ref(input: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dim: int = 1,
act: str = 'linear',
alpha: Optional[Union[float, int]] = None,
gain: Optional[float] = None,
clamp: Optional[float] = None):
"""Slow reference implementation of `bias_act()` using standard PyTorch
ops.
    Adds `bias` to the activation tensor `input`, evaluates the activation
    function `act`, and scales the result by `gain`. Each of these steps is
    optional.
In most cases, the fused op is considerably more efficient than performing
the same calculation using standard PyTorch ops. It supports first and
second order gradients, but not third order gradients.
Args:
input (torch.Tensor): Input activation tensor. Can be of any shape.
bias (torch.Tensor): Bias vector, or `None` to disable.
Must be a 1D tensor of the same type as `input`. The shape must
be known, and it must match the dimension of `input` corresponding
to `dim`. Defaults to None.
dim (int): The dimension in `input` corresponding to the elements of
            `bias`. The value of `dim` is ignored if `bias` is not specified.
Defaults to 1.
act (str): Name of the activation function to evaluate, or `"linear"`
to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid",
"swish", etc. See `activation_funcs` for a full list. `None` is not
allowed. Defaults to `linear`.
alpha (float or int): Shape parameter for the activation
function, or `None` to use the default. Defaults to None.
gain (float): Scaling factor for the output tensor, or `None`
to use default. See `activation_funcs` for the default scaling of
each activation function. If unsure, consider specifying 1.
Defaults to None.
clamp (float): Clamp the output values to
`[-clamp, +clamp]`, or `None` to disable the clamping (default).
Defaults to None.
Returns:
torch.Tensor: Tensor of the same shape and datatype as `input`.
"""
assert isinstance(input, torch.Tensor)
assert clamp is None or clamp >= 0
spec = activation_funcs[act]
alpha = float(alpha if alpha is not None else spec.def_alpha)
gain = float(gain if gain is not None else spec.def_gain)
clamp = float(clamp if clamp is not None else -1)
# Add bias.
if bias is not None:
assert isinstance(bias, torch.Tensor) and bias.ndim == 1
assert 0 <= dim < input.ndim
assert bias.shape[0] == input.shape[dim]
input = input + bias.reshape(
[-1 if i == dim else 1 for i in range(input.ndim)])
# Evaluate activation function.
alpha = float(alpha)
output = spec.func(input, alpha=alpha)
# Scale by gain.
gain = float(gain)
if gain != 1:
output = output * gain
# Clamp.
if clamp >= 0:
# pylint: disable=invalid-unary-operand-type
output = output.clamp(-clamp, clamp)
return output
_bias_act_cuda_cache: Dict = dict()
def _bias_act_cuda(dim: int = 1,
act: str = 'linear',
alpha: Optional[Union[float, int]] = None,
gain: Optional[float] = None,
clamp: Optional[float] = None):
""""Fast CUDA implementation of `bias_act()` using custom ops.
Args:
dim (int): The dimension in `x` corresponding to the elements of `b`.
The value of `dim` is ignored if `b` is not specified.
Defaults to 1.
act (str): Name of the activation function to evaluate, or `"linear"`
to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid",
"swish", etc. See `activation_funcs` for a full list. `None` is not
allowed. Defaults to `linear`.
alpha (float | int): Shape parameter for the activation
function, or `None` to use the default. Defaults to None.
gain (float): Scaling factor for the output tensor, or `None`
to use default. See `activation_funcs` for the default scaling of
each activation function. If unsure, consider specifying 1.
Defaults to None.
clamp (float): Clamp the output values to `[-clamp, +clamp]`,
or `None` to disable the clamping (default). Defaults to None.
    Returns:
        Function: An autograd function object that implements `bias_act()`
            for the given combination of arguments.
"""
# Parse arguments.
assert clamp is None or clamp >= 0
spec = activation_funcs[act]
alpha = float(alpha if alpha is not None else spec.def_alpha)
gain = float(gain if gain is not None else spec.def_gain)
clamp = float(clamp if clamp is not None else -1)
# Lookup from cache.
key = (dim, act, alpha, gain, clamp)
if key in _bias_act_cuda_cache:
return _bias_act_cuda_cache[key]
# Forward op.
class BiasActCuda(torch.autograd.Function):
@staticmethod
def forward(ctx, x, b): # pylint: disable=arguments-differ
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(
1) == 1 else torch.contiguous_format
x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor.to(x.device)
            y = x
            if act != 'linear' or gain != 1 or clamp >= 0 or b.numel() != 0:
y = ext_module.bias_act(x, b, _null_tensor.to(x.device),
_null_tensor.to(x.device),
_null_tensor.to(x.device), 0, dim,
spec.cuda_idx, alpha, gain, clamp)
ctx.save_for_backward(
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor.to(
x.device), b if 'x' in spec.ref or spec.has_2nd_grad else
_null_tensor.to(x.device),
y if 'y' in spec.ref else _null_tensor.to(x.device))
return y
@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
dy = dy.contiguous(memory_format=ctx.memory_format)
x, b, y = ctx.saved_tensors
dx = None
db = None
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
dx = dy
if act != 'linear' or gain != 1 or clamp >= 0:
dx = BiasActCudaGrad.apply(dy, x, b, y)
if ctx.needs_input_grad[1]:
db = dx.sum([i for i in range(dx.ndim) if i != dim])
return dx, db
# Backward op.
class BiasActCudaGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
ctx.memory_format = torch.channels_last if dy.ndim > 2 and (
dy.stride(1) == 1) else torch.contiguous_format
dx = ext_module.bias_act(dy, b, x, y, _null_tensor.to(x.device), 1,
dim, spec.cuda_idx, alpha, gain, clamp)
ctx.save_for_backward(
dy if spec.has_2nd_grad else _null_tensor.to(x.device), x, b,
y)
return dx
@staticmethod
def backward(ctx, d_dx): # pylint: disable=arguments-differ
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
dy, x, b, y = ctx.saved_tensors
d_dy = None
d_x = None
d_b = None
d_y = None
if ctx.needs_input_grad[0]:
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
if spec.has_2nd_grad and (ctx.needs_input_grad[1]
or ctx.needs_input_grad[2]):
d_x = ext_module.bias_act(d_dx, b, x, y, dy, 2, dim,
spec.cuda_idx, alpha, gain, clamp)
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
return d_dy, d_x, d_b, d_y
# Add to cache.
_bias_act_cuda_cache[key] = BiasActCuda
return BiasActCuda
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# source: https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/conv2d_gradfix.py # noqa
"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""
import contextlib
import warnings
from typing import Dict, Optional, Tuple, Union
import torch
enabled = True
weight_gradients_disabled = False
@contextlib.contextmanager
def no_weight_gradients(disable=True):
global weight_gradients_disabled
old = weight_gradients_disabled
if disable:
weight_gradients_disabled = True
yield
weight_gradients_disabled = old
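# Usage sketch (illustrative): temporarily skip the gradient w.r.t. the
# convolution weights, e.g. during a StyleGAN-style regularization pass where
# only the input gradient is needed. `inp`, `weight` and `bias` below are
# placeholder tensors, not part of this module.
#
#   with no_weight_gradients():
#       out = conv2d(inp, weight, bias)
#       out.sum().backward()  # weight.grad is not populated inside the block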
def conv2d(input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...]] = 0,
dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1):
flag = True
if torch.__version__ >= '1.10.0':
        warnings.warn(
            'Since `aten::cudnn_convolution_backward_weight` is not '
            f'supported in torch=={torch.__version__}, falling back to '
            '`torch.nn.functional.conv2d`')
flag = False
if _should_use_custom_op(input) and flag:
return _conv2d_gradfix(
transpose=False,
weight_shape=weight.shape,
stride=stride,
padding=padding,
output_padding=0,
dilation=dilation,
groups=groups).apply(input, weight, bias)
return torch.nn.functional.conv2d(
input=input,
weight=weight,
bias=bias,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups)
def conv_transpose2d(input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...]] = 0,
output_padding: Union[int, Tuple[int, ...]] = 0,
groups: int = 1,
dilation: Union[int, Tuple[int, ...]] = 1):
if _should_use_custom_op(input):
return _conv2d_gradfix(
transpose=True,
weight_shape=weight.shape,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation).apply(input, weight, bias)
return torch.nn.functional.conv_transpose2d(
input=input,
weight=weight,
bias=bias,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
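# The `conv2d` and `conv_transpose2d` wrappers above are intended as drop-in
# replacements for the corresponding `torch.nn.functional` calls: when the
# custom gradfix op cannot be used (CPU input, cuDNN disabled, or, for
# `conv2d`, a PyTorch version without the needed backward op) they fall back
# to the standard implementation, so callers do not need to branch on device
# or version themselves.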
def _should_use_custom_op(input):
assert isinstance(input, torch.Tensor)
if (not enabled) or (not torch.backends.cudnn.enabled):
return False
if input.device.type != 'cuda':
return False
return True
def _to_tuple(x, ndim):
xs = tuple(x) if isinstance(x, (tuple, list)) else (x, ) * ndim
assert len(xs) == ndim
assert all(isinstance(x, int) for x in xs)
return xs
_conv2d_gradfix_cache: Dict = dict()
_null_tensor = torch.empty([0])
def _conv2d_gradfix(
transpose: bool,
weight_shape: Tuple[int, ...],
stride: Union[int, Tuple[int, ...]],
padding: Union[int, Tuple[int, ...]],
output_padding: Union[int, Tuple[int, ...]],
dilation: Union[int, Tuple[int, ...]],
groups: int,
):
# Parse arguments.
ndim = 2
weight_shape = tuple(weight_shape)
stride = _to_tuple(stride, ndim)
padding = _to_tuple(padding, ndim)
output_padding = _to_tuple(output_padding, ndim)
dilation = _to_tuple(dilation, ndim)
# Lookup from cache.
key = (transpose, weight_shape, stride, padding, output_padding, dilation,
groups)
if key in _conv2d_gradfix_cache:
return _conv2d_gradfix_cache[key]
# Validate arguments.
assert groups >= 1
assert len(weight_shape) == ndim + 2
assert all(stride[i] >= 1 for i in range(ndim)) # type: ignore
assert all(padding[i] >= 0 for i in range(ndim)) # type: ignore
assert all(dilation[i] >= 0 for i in range(ndim)) # type: ignore
if not transpose:
assert all(output_padding[i] == 0 for i in range(ndim)) # type: ignore
else: # transpose
for i in range(ndim):
assert 0 <= output_padding[i] < max( # type: ignore
stride[i], # type: ignore
dilation[i]) # type: ignore
# Helpers.
common_kwargs = dict(
stride=stride, padding=padding, dilation=dilation, groups=groups)
def calc_output_padding(input_shape, output_shape):
if transpose:
return [0, 0]
return [
input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] -
(1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1)
for i in range(ndim)
]
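    # For the non-transposed case, the gradient w.r.t. the input is itself a
    # transposed convolution of `grad_output`; `calc_output_padding` solves
    #   input = (output - 1) * stride - 2 * padding
    #           + dilation * (kernel - 1) + 1 + output_padding
    # for `output_padding` in each spatial dimension, so that the backward
    # convolution reproduces exactly the original input shape.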
# Forward & backward.
class Conv2d(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias):
assert weight.shape == weight_shape
ctx.save_for_backward(
input if weight.requires_grad else _null_tensor,
weight if input.requires_grad else _null_tensor,
)
ctx.input_shape = input.shape
# Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
if weight_shape[2:] == stride == dilation == (
1, 1) and padding == (
0, 0) and torch.cuda.get_device_capability(
input.device) < (8, 0):
a = weight.reshape(groups, weight_shape[0] // groups,
weight_shape[1])
b = input.reshape(input.shape[0], groups,
input.shape[1] // groups, -1)
c = (a.transpose(1, 2) if transpose else a) @ b.permute(
1, 2, 0, 3).flatten(2)
c = c.reshape(-1, input.shape[0],
*input.shape[2:]).transpose(0, 1)
c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(
2).unsqueeze(3)
return c.contiguous(
memory_format=(torch.channels_last if input.stride(1) ==
1 else torch.contiguous_format))
# General case => cuDNN.
if transpose:
return torch.nn.functional.conv_transpose2d(
input=input,
weight=weight,
bias=bias,
output_padding=output_padding,
**common_kwargs)
return torch.nn.functional.conv2d(
input=input, weight=weight, bias=bias, **common_kwargs)
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
input_shape = ctx.input_shape
grad_input = None
grad_weight = None
grad_bias = None
if ctx.needs_input_grad[0]:
p = calc_output_padding(
input_shape=input_shape, output_shape=grad_output.shape)
op = _conv2d_gradfix(
transpose=(not transpose),
weight_shape=weight_shape,
output_padding=p,
**common_kwargs)
grad_input = op.apply(grad_output, weight, None)
assert grad_input.shape == input_shape
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
grad_weight = Conv2dGradWeight.apply(grad_output, input)
assert grad_weight.shape == weight_shape
if ctx.needs_input_grad[2]:
grad_bias = grad_output.sum([0, 2, 3])
return grad_input, grad_weight, grad_bias
# Gradient with respect to the weights.
class Conv2dGradWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input):
ctx.save_for_backward(
grad_output if input.requires_grad else _null_tensor,
input if grad_output.requires_grad else _null_tensor,
)
ctx.grad_output_shape = grad_output.shape
ctx.input_shape = input.shape
# Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
if weight_shape[2:] == stride == dilation == (
1, 1) and padding == (0, 0):
a = grad_output.reshape(grad_output.shape[0], groups,
grad_output.shape[1] // groups,
-1).permute(1, 2, 0, 3).flatten(2)
b = input.reshape(input.shape[0], groups,
input.shape[1] // groups,
-1).permute(1, 2, 0, 3).flatten(2)
c = (b @ a.transpose(1, 2) if transpose else
a @ b.transpose(1, 2)).reshape(weight_shape)
return c.contiguous(
memory_format=(torch.channels_last if input.stride(1) ==
1 else torch.contiguous_format))
# General case => cuDNN.
name = ('aten::cudnn_convolution_transpose_backward_weight' if
transpose else 'aten::cudnn_convolution_backward_weight')
flags = [
torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic,
torch.backends.cudnn.allow_tf32
]
return torch._C._jit_get_operation(name)(weight_shape, grad_output,
input, padding, stride,
dilation, groups, *flags)
@staticmethod
def backward(ctx, grad2_grad_weight):
grad_output, input = ctx.saved_tensors
grad_output_shape = ctx.grad_output_shape
input_shape = ctx.input_shape
grad2_grad_output = None
grad2_input = None
if ctx.needs_input_grad[0]:
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight,
None)
assert grad2_grad_output.shape == grad_output_shape
if ctx.needs_input_grad[1]:
p = calc_output_padding(
input_shape=input_shape, output_shape=grad_output_shape)
op = _conv2d_gradfix(
transpose=(not transpose),
weight_shape=weight_shape,
output_padding=p,
**common_kwargs)
grad2_input = op.apply(grad_output, grad2_grad_weight, None)
assert grad2_input.shape == input_shape
return grad2_grad_output, grad2_input
_conv2d_gradfix_cache[key] = Conv2d
return Conv2d
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
torch::Tensor bias_act_op_impl(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &xref,
const torch::Tensor &yref,
const torch::Tensor &dy, int grad, int dim,
int act, float alpha, float gain, float clamp) {
return DISPATCH_DEVICE_IMPL(bias_act_op_impl, input, bias, xref, yref, dy,
grad, dim, act, alpha, gain, clamp);
}
torch::Tensor bias_act(const torch::Tensor &input, const torch::Tensor &bias,
const torch::Tensor &xref, const torch::Tensor &yref,
const torch::Tensor &dy, int grad, int dim, int act,
float alpha, float gain, float clamp) {
return bias_act_op_impl(input, bias, xref, yref, dy, grad, dim, act, alpha,
gain, clamp);
}
// Modified from
// https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/bias_act.cpp
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "pytorch_cuda_helper.hpp"
struct bias_act_kernel_params {
const void* x; // [sizeX]
const void* b; // [sizeB] or NULL
const void* xref; // [sizeX] or NULL
const void* yref; // [sizeX] or NULL
const void* dy; // [sizeX] or NULL
void* y; // [sizeX]
int grad;
int act;
float alpha;
float gain;
float clamp;
int sizeX;
int sizeB;
int stepB;
int loopX;
};
// CUDA kernel selection.
template <class T>
void* choose_bias_act_kernel(const bias_act_kernel_params& p);
//------------------------------------------------------------------------
// Helpers.
template <class T>
struct InternalType;
template <>
struct InternalType<double> {
typedef double scalar_t;
};
template <>
struct InternalType<float> {
typedef float scalar_t;
};
template <>
struct InternalType<c10::Half> {
typedef float scalar_t;
};
//------------------------------------------------------------------------
// CUDA kernel.
template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p) {
typedef typename InternalType<T>::scalar_t scalar_t;
int G = p.grad;
scalar_t alpha = (scalar_t)p.alpha;
scalar_t gain = (scalar_t)p.gain;
scalar_t clamp = (scalar_t)p.clamp;
scalar_t one = (scalar_t)1;
scalar_t two = (scalar_t)2;
scalar_t expRange = (scalar_t)80;
scalar_t halfExpRange = (scalar_t)40;
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
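  // p.grad selects what the kernel computes: G == 0 evaluates the forward
  // activation y = f(x + b), G == 1 multiplies the incoming gradient by the
  // first derivative f'(.), and G == 2 applies the second derivative for
  // double-backward. `xref`/`yref` hold the saved forward input/output that
  // the chosen derivative needs.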
// Loop over elements.
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX;
loopIdx++, xi += blockDim.x) {
// Load.
scalar_t x = (scalar_t)((const T*)p.x)[xi];
scalar_t b =
(p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
scalar_t yy = (gain != 0) ? yref / gain : 0;
scalar_t y = 0;
// Apply bias.
((G == 0) ? x : xref) += b;
// linear
if (A == 1) {
if (G == 0) y = x;
if (G == 1) y = x;
}
// relu
if (A == 2) {
if (G == 0) y = (x > 0) ? x : 0;
if (G == 1) y = (yy > 0) ? x : 0;
}
// lrelu
if (A == 3) {
if (G == 0) y = (x > 0) ? x : x * alpha;
if (G == 1) y = (yy > 0) ? x : x * alpha;
}
// tanh
if (A == 4) {
if (G == 0) {
scalar_t c = exp(x);
scalar_t d = one / c;
y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d);
}
if (G == 1) y = x * (one - yy * yy);
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
}
// sigmoid
if (A == 5) {
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
if (G == 1) y = x * yy * (one - yy);
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
}
// elu
if (A == 6) {
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
}
// selu
if (A == 7) {
if (G == 0)
y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
if (G == 1)
y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
}
// softplus
if (A == 8) {
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
if (G == 1) y = x * (one - exp(-yy));
if (G == 2) {
scalar_t c = exp(-yy);
y = x * c * (one - c);
}
}
// swish
if (A == 9) {
if (G == 0)
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
else {
scalar_t c = exp(xref);
scalar_t d = c + one;
if (G == 1)
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
else
y = (xref > halfExpRange)
? 0
: x * c * (xref * (two - d) + two * d) / (d * d * d);
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
}
}
// Apply gain.
y *= gain * dy;
// Clamp.
if (clamp >= 0) {
if (G == 0)
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
else
y = (yref > -clamp & yref < clamp) ? y : 0;
}
// Store.
((T*)p.y)[xi] = (T)y;
}
}
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T>
void* choose_bias_act_kernel(const bias_act_kernel_params& p) {
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
return NULL;
}
//------------------------------------------------------------------------
static bool has_same_layout(torch::Tensor x, torch::Tensor y) {
if (x.dim() != y.dim()) return false;
for (int64_t i = 0; i < x.dim(); i++) {
if (x.size(i) != y.size(i)) return false;
if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) return false;
}
return true;
}
//------------------------------------------------------------------------
torch::Tensor bias_act_op(const torch::Tensor& x, const torch::Tensor& b,
const torch::Tensor& xref, const torch::Tensor& yref,
const torch::Tensor& dy, int grad, int dim, int act,
float alpha, float gain, float clamp) {
// Validate arguments.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
TORCH_CHECK(
b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()),
"b must have the same dtype and device as x");
TORCH_CHECK(xref.numel() == 0 ||
(xref.sizes() == x.sizes() && xref.dtype() == x.dtype() &&
xref.device() == x.device()),
"xref must have the same shape, dtype, and device as x");
TORCH_CHECK(yref.numel() == 0 ||
(yref.sizes() == x.sizes() && yref.dtype() == x.dtype() &&
yref.device() == x.device()),
"yref must have the same shape, dtype, and device as x");
TORCH_CHECK(
dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() &&
dy.device() == x.device()),
"dy must have the same dtype and device as x");
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()),
"dim is out of bounds");
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim),
"b has wrong number of elements");
TORCH_CHECK(grad >= 0, "grad must be non-negative");
// Validate layout.
TORCH_CHECK(x.is_non_overlapping_and_dense(),
"x must be non-overlapping and dense");
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x),
"xref must have the same layout as x");
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x),
"yref must have the same layout as x");
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x),
"dy must have the same layout as x");
// Create output tensor.
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
torch::Tensor y = torch::empty_like(x);
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
// Initialize CUDA kernel parameters.
bias_act_kernel_params p;
p.x = x.data_ptr();
p.b = (b.numel()) ? b.data_ptr() : NULL;
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
p.y = y.data_ptr();
p.grad = grad;
p.act = act;
p.alpha = alpha;
p.gain = gain;
p.clamp = clamp;
p.sizeX = (int)x.numel();
p.sizeB = (int)b.numel();
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
// Choose CUDA kernel.
void* kernel;
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&] {
kernel = choose_bias_act_kernel<scalar_t>(p);
});
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
// Launch CUDA kernel.
p.loopX = 4;
int blockSize = 4 * 32;
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
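  // Each block covers a contiguous chunk of p.loopX * blockSize elements
  // (every thread strides through p.loopX of them), so gridSize blocks are
  // enough to touch all p.sizeX elements of x.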
void* args[] = {&p};
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0,
at::cuda::getCurrentCUDAStream()));
return y;
}
......@@ -3,45 +3,45 @@
void AssignScoreWithKForwardCUDAKernelLauncher(
int B, int N0, int N1, int M, int K, int O, int aggregate,
const Tensor& points, const Tensor& centers, const Tensor& scores,
const Tensor& knn_idx, Tensor& output);
const Tensor &points, const Tensor &centers, const Tensor &scores,
const Tensor &knn_idx, Tensor &output);
void AssignScoreWithKBackwardCUDAKernelLauncher(
int B, int N0, int N1, int M, int K, int O, int aggregate,
const Tensor& grad_out, const Tensor& points, const Tensor& centers,
const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points,
Tensor& grad_centers, Tensor& grad_scores);
const Tensor &grad_out, const Tensor &points, const Tensor &centers,
const Tensor &scores, const Tensor &knn_idx, Tensor &grad_points,
Tensor &grad_centers, Tensor &grad_scores);
void assign_score_withk_forward_cuda(int B, int N0, int N1, int M, int K, int O,
int aggregate, const Tensor& points,
const Tensor& centers,
const Tensor& scores,
const Tensor& knn_idx, Tensor& output) {
int aggregate, const Tensor &points,
const Tensor &centers,
const Tensor &scores,
const Tensor &knn_idx, Tensor &output) {
AssignScoreWithKForwardCUDAKernelLauncher(
B, N0, N1, M, K, O, aggregate, points, centers, scores, knn_idx, output);
};
void assign_score_withk_backward_cuda(
int B, int N0, int N1, int M, int K, int O, int aggregate,
const Tensor& grad_out, const Tensor& points, const Tensor& centers,
const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points,
Tensor& grad_centers, Tensor& grad_scores) {
const Tensor &grad_out, const Tensor &points, const Tensor &centers,
const Tensor &scores, const Tensor &knn_idx, Tensor &grad_points,
Tensor &grad_centers, Tensor &grad_scores) {
AssignScoreWithKBackwardCUDAKernelLauncher(
B, N0, N1, M, K, O, aggregate, grad_out, points, centers, scores, knn_idx,
grad_points, grad_centers, grad_scores);
};
void assign_score_withk_forward_impl(int B, int N0, int N1, int M, int K, int O,
int aggregate, const Tensor& points,
const Tensor& centers,
const Tensor& scores,
const Tensor& knn_idx, Tensor& output);
int aggregate, const Tensor &points,
const Tensor &centers,
const Tensor &scores,
const Tensor &knn_idx, Tensor &output);
void assign_score_withk_backward_impl(
int B, int N0, int N1, int M, int K, int O, int aggregate,
const Tensor& grad_out, const Tensor& points, const Tensor& centers,
const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points,
Tensor& grad_centers, Tensor& grad_scores);
const Tensor &grad_out, const Tensor &points, const Tensor &centers,
const Tensor &scores, const Tensor &knn_idx, Tensor &grad_points,
Tensor &grad_centers, Tensor &grad_scores);
REGISTER_DEVICE_IMPL(assign_score_withk_forward_impl, CUDA,
assign_score_withk_forward_cuda);
......@@ -104,37 +104,37 @@ void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset);
REGISTER_DEVICE_IMPL(bbox_overlaps_impl, CUDA, bbox_overlaps_cuda);
void BorderAlignForwardCUDAKernelLauncher(const Tensor& input,
const Tensor& boxes, Tensor output,
void BorderAlignForwardCUDAKernelLauncher(const Tensor &input,
const Tensor &boxes, Tensor output,
Tensor argmax_idx,
const int pool_size);
void BorderAlignBackwardCUDAKernelLauncher(const Tensor& grad_output,
const Tensor& boxes,
const Tensor& argmax_idx,
void BorderAlignBackwardCUDAKernelLauncher(const Tensor &grad_output,
const Tensor &boxes,
const Tensor &argmax_idx,
Tensor grad_input,
const int pool_size);
void border_align_forward_cuda(const Tensor& input, const Tensor& boxes,
void border_align_forward_cuda(const Tensor &input, const Tensor &boxes,
Tensor output, Tensor argmax_idx,
const int pool_size) {
BorderAlignForwardCUDAKernelLauncher(input, boxes, output, argmax_idx,
pool_size);
}
void border_align_backward_cuda(const Tensor& grad_output, const Tensor& boxes,
const Tensor& argmax_idx, Tensor grad_input,
void border_align_backward_cuda(const Tensor &grad_output, const Tensor &boxes,
const Tensor &argmax_idx, Tensor grad_input,
const int pool_size) {
BorderAlignBackwardCUDAKernelLauncher(grad_output, boxes, argmax_idx,
grad_input, pool_size);
}
void border_align_forward_impl(const Tensor& input, const Tensor& boxes,
void border_align_forward_impl(const Tensor &input, const Tensor &boxes,
Tensor output, Tensor argmax_idx,
const int pool_size);
void border_align_backward_impl(const Tensor& grad_output, const Tensor& boxes,
const Tensor& argmax_idx, Tensor grad_input,
void border_align_backward_impl(const Tensor &grad_output, const Tensor &boxes,
const Tensor &argmax_idx, Tensor grad_input,
const int pool_size);
REGISTER_DEVICE_IMPL(border_align_forward_impl, CUDA,
......@@ -472,18 +472,18 @@ REGISTER_DEVICE_IMPL(softmax_focal_loss_backward_impl, CUDA,
softmax_focal_loss_backward_cuda);
void FurthestPointSamplingForwardCUDAKernelLauncher(int b, int n, int m,
const float* dataset,
float* temp, int* idxs);
const float *dataset,
float *temp, int *idxs);
void FurthestPointSamplingWithDistForwardCUDAKernelLauncher(
int b, int n, int m, const float* dataset, float* temp, int* idxs);
int b, int n, int m, const float *dataset, float *temp, int *idxs);
void furthest_point_sampling_forward_cuda(Tensor points_tensor,
Tensor temp_tensor, Tensor idx_tensor,
int b, int n, int m) {
const float* dataset = points_tensor.data_ptr<float>();
float* temp = temp_tensor.data_ptr<float>();
int* idxs = idx_tensor.data_ptr<int>();
const float *dataset = points_tensor.data_ptr<float>();
float *temp = temp_tensor.data_ptr<float>();
int *idxs = idx_tensor.data_ptr<int>();
FurthestPointSamplingForwardCUDAKernelLauncher(b, n, m, dataset, temp, idxs);
}
......@@ -491,9 +491,9 @@ void furthest_point_sampling_with_dist_forward_cuda(Tensor points_tensor,
Tensor temp_tensor,
Tensor idx_tensor, int b,
int n, int m) {
const float* dataset = points_tensor.data_ptr<float>();
float* temp = temp_tensor.data_ptr<float>();
int* idxs = idx_tensor.data_ptr<int>();
const float *dataset = points_tensor.data_ptr<float>();
float *temp = temp_tensor.data_ptr<float>();
int *idxs = idx_tensor.data_ptr<int>();
FurthestPointSamplingWithDistForwardCUDAKernelLauncher(b, n, m, dataset, temp,
idxs);
}
......@@ -512,18 +512,57 @@ REGISTER_DEVICE_IMPL(furthest_point_sampling_forward_impl, CUDA,
REGISTER_DEVICE_IMPL(furthest_point_sampling_with_dist_forward_impl, CUDA,
furthest_point_sampling_with_dist_forward_cuda);
torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor& input,
const torch::Tensor& bias,
const torch::Tensor& refer, int act,
torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &refer, int act,
int grad, float alpha, float scale);
torch::Tensor fused_bias_leakyrelu_op_impl(const torch::Tensor& input,
const torch::Tensor& bias,
const torch::Tensor& refer, int act,
torch::Tensor fused_bias_leakyrelu_op_impl(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &refer, int act,
int grad, float alpha, float scale);
REGISTER_DEVICE_IMPL(fused_bias_leakyrelu_op_impl, CUDA,
fused_bias_leakyrelu_op);
torch::Tensor bias_act_op_impl(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &xref,
const torch::Tensor &yref,
const torch::Tensor &dy, int grad, int dim,
int act, float alpha, float gain, float clamp);
torch::Tensor bias_act_op(const torch::Tensor &input, const torch::Tensor &bias,
const torch::Tensor &xref, const torch::Tensor &yref,
const torch::Tensor &dy, int grad, int dim, int act,
float alpha, float gain, float clamp);
REGISTER_DEVICE_IMPL(bias_act_op_impl, CUDA, bias_act_op);
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op_impl(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns);
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns);
REGISTER_DEVICE_IMPL(filtered_lrelu_op_impl, CUDA, filtered_lrelu_op);
torch::Tensor filtered_lrelu_act_op_impl(torch::Tensor x, torch::Tensor si,
int sx, int sy, float gain,
float slope, float clamp,
bool writeSigns);
torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
int sy, float gain, float slope,
float clamp, bool writeSigns);
REGISTER_DEVICE_IMPL(filtered_lrelu_act_op_impl, CUDA, filtered_lrelu_act_op);
void GatherPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
const Tensor points,
const Tensor idx, Tensor out);
......@@ -651,12 +690,12 @@ void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_b,
Tensor ans_overlap);
void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, Tensor &keep,
Tensor &keep_num,
float nms_overlap_thresh);
void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, Tensor &keep,
Tensor &keep_num,
float nms_overlap_thresh);
void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
......@@ -666,14 +705,14 @@ void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
ans_overlap);
};
void iou3d_nms3d_forward_cuda(const Tensor boxes, Tensor& keep,
Tensor& keep_num, float nms_overlap_thresh) {
void iou3d_nms3d_forward_cuda(const Tensor boxes, Tensor &keep,
Tensor &keep_num, float nms_overlap_thresh) {
IoU3DNMS3DForwardCUDAKernelLauncher(boxes, keep, keep_num,
nms_overlap_thresh);
};
void iou3d_nms3d_normal_forward_cuda(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
void iou3d_nms3d_normal_forward_cuda(const Tensor boxes, Tensor &keep,
Tensor &keep_num,
float nms_overlap_thresh) {
IoU3DNMS3DNormalForwardCUDAKernelLauncher(boxes, keep, keep_num,
nms_overlap_thresh);
......@@ -683,11 +722,11 @@ void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b,
Tensor ans_overlap);
void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor& keep,
Tensor& keep_num, float nms_overlap_thresh);
void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num, float nms_overlap_thresh);
void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num,
float nms_overlap_thresh);
REGISTER_DEVICE_IMPL(iou3d_boxes_overlap_bev_forward_impl, CUDA,
......@@ -812,31 +851,31 @@ REGISTER_DEVICE_IMPL(modulated_deformable_col2im_impl, CUDA,
REGISTER_DEVICE_IMPL(modulated_deformable_col2im_coord_impl, CUDA,
modulated_deformable_col2im_coord_cuda);
Tensor ms_deform_attn_cuda_forward(const Tensor& value,
const Tensor& spatial_shapes,
const Tensor& level_start_index,
const Tensor& sampling_loc,
const Tensor& attn_weight,
Tensor ms_deform_attn_cuda_forward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const int im2col_step);
void ms_deform_attn_cuda_backward(
const Tensor& value, const Tensor& spatial_shapes,
const Tensor& level_start_index, const Tensor& sampling_loc,
const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value,
Tensor& grad_sampling_loc, Tensor& grad_attn_weight, const int im2col_step);
Tensor ms_deform_attn_impl_forward(const Tensor& value,
const Tensor& spatial_shapes,
const Tensor& level_start_index,
const Tensor& sampling_loc,
const Tensor& attn_weight,
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step);
Tensor ms_deform_attn_impl_forward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const int im2col_step);
void ms_deform_attn_impl_backward(
const Tensor& value, const Tensor& spatial_shapes,
const Tensor& level_start_index, const Tensor& sampling_loc,
const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value,
Tensor& grad_sampling_loc, Tensor& grad_attn_weight, const int im2col_step);
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step);
REGISTER_DEVICE_IMPL(ms_deform_attn_impl_forward, CUDA,
ms_deform_attn_cuda_forward);
......@@ -1244,26 +1283,26 @@ REGISTER_DEVICE_IMPL(roi_pool_backward_impl, CUDA, roi_pool_backward_cuda);
typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
std::vector<at::Tensor> DynamicPointToVoxelForwardCUDAKernelLauncher(
const at::Tensor& feats, const at::Tensor& coors,
const at::Tensor &feats, const at::Tensor &coors,
const reduce_t reduce_type);
void DynamicPointToVoxelBackwardCUDAKernelLauncher(
at::Tensor& grad_feats, const at::Tensor& grad_reduced_feats,
const at::Tensor& feats, const at::Tensor& reduced_feats,
const at::Tensor& coors_map, const at::Tensor& reduce_count,
at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats,
const at::Tensor &feats, const at::Tensor &reduced_feats,
const at::Tensor &coors_map, const at::Tensor &reduce_count,
const reduce_t reduce_type);
std::vector<torch::Tensor> dynamic_point_to_voxel_forward_cuda(
const torch::Tensor& feats, const torch::Tensor& coors,
const torch::Tensor &feats, const torch::Tensor &coors,
const reduce_t reduce_type) {
return DynamicPointToVoxelForwardCUDAKernelLauncher(feats, coors,
reduce_type);
};
void dynamic_point_to_voxel_backward_cuda(
torch::Tensor& grad_feats, const torch::Tensor& grad_reduced_feats,
const torch::Tensor& feats, const torch::Tensor& reduced_feats,
const torch::Tensor& coors_idx, const torch::Tensor& reduce_count,
torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats, const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
const reduce_t reduce_type) {
DynamicPointToVoxelBackwardCUDAKernelLauncher(grad_feats, grad_reduced_feats,
feats, reduced_feats, coors_idx,
......@@ -1271,13 +1310,13 @@ void dynamic_point_to_voxel_backward_cuda(
};
std::vector<torch::Tensor> dynamic_point_to_voxel_forward_impl(
const torch::Tensor& feats, const torch::Tensor& coors,
const torch::Tensor &feats, const torch::Tensor &coors,
const reduce_t reduce_type);
void dynamic_point_to_voxel_backward_impl(
torch::Tensor& grad_feats, const torch::Tensor& grad_reduced_feats,
const torch::Tensor& feats, const torch::Tensor& reduced_feats,
const torch::Tensor& coors_idx, const torch::Tensor& reduce_count,
torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats, const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
const reduce_t reduce_type);
REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, CUDA,
......@@ -1443,37 +1482,36 @@ void tin_shift_backward_impl(Tensor grad_output, Tensor shift,
REGISTER_DEVICE_IMPL(tin_shift_forward_impl, CUDA, tin_shift_forward_cuda);
REGISTER_DEVICE_IMPL(tin_shift_backward_impl, CUDA, tin_shift_backward_cuda);
torch::Tensor upfirdn2d_op(const torch::Tensor& input,
const torch::Tensor& kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1);
torch::Tensor upfirdn2d_op(torch::Tensor input, torch::Tensor filter, int upx,
int upy, int downx, int downy, int padx0, int padx1,
int pady0, int pady1, bool flip, float gain);
torch::Tensor upfirdn2d_op_impl(const torch::Tensor& input,
const torch::Tensor& kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1);
torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor filter,
int upx, int upy, int downx, int downy,
int padx0, int padx1, int pady0, int pady1,
bool flip, float gain);
REGISTER_DEVICE_IMPL(upfirdn2d_op_impl, CUDA, upfirdn2d_op);
int HardVoxelizeForwardCUDAKernelLauncher(
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3);
int NondeterministicHardVoxelizeForwardCUDAKernelLauncher(
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3);
void DynamicVoxelizeForwardCUDAKernelLauncher(
const at::Tensor& points, at::Tensor& coors,
const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size, const std::vector<float> coors_range,
const int NDim = 3);
int hard_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& voxels,
at::Tensor& coors,
at::Tensor& num_points_per_voxel,
int hard_voxelize_forward_cuda(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors,
at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
......@@ -1484,8 +1522,8 @@ int hard_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& voxels,
};
int nondeterministic_hard_voxelize_forward_cuda(
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim) {
return NondeterministicHardVoxelizeForwardCUDAKernelLauncher(
......@@ -1493,7 +1531,7 @@ int nondeterministic_hard_voxelize_forward_cuda(
max_points, max_voxels, NDim);
};
void dynamic_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& coors,
void dynamic_voxelize_forward_cuda(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim) {
......@@ -1501,21 +1539,21 @@ void dynamic_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& coors,
coors_range, NDim);
};
int hard_voxelize_forward_impl(const at::Tensor& points, at::Tensor& voxels,
at::Tensor& coors,
at::Tensor& num_points_per_voxel,
int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors,
at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim);
int nondeterministic_hard_voxelize_forward_impl(
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim);
void dynamic_voxelize_forward_impl(const at::Tensor& points, at::Tensor& coors,
void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim);
......
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <cstdint>
#include "pytorch_cuda_helper.hpp"
//------------------------------------------------------------------------
// CUDA kernel parameters.
struct filtered_lrelu_kernel_params {
// These parameters decide which kernel to use.
int up; // upsampling ratio (1, 2, 4)
int down; // downsampling ratio (1, 2, 4)
int2 fuShape; // [size, 1] | [size, size]
int2 fdShape; // [size, 1] | [size, size]
int _dummy; // Alignment.
// Rest of the parameters.
const void* x; // Input tensor.
void* y; // Output tensor.
const void* b; // Bias tensor.
unsigned char* s; // Sign tensor in/out. NULL if unused.
const float* fu; // Upsampling filter.
const float* fd; // Downsampling filter.
int2 pad0; // Left/top padding.
float gain; // Additional gain factor.
float slope; // Leaky ReLU slope on negative side.
float clamp; // Clamp after nonlinearity.
int flip; // Filter kernel flip for gradient computation.
int tilesXdim; // Original number of horizontal output tiles.
int tilesXrep; // Number of horizontal tiles per CTA.
int blockZofs; // Block z offset to support large minibatch, channel
// dimensions.
int4 xShape; // [width, height, channel, batch]
int4 yShape; // [width, height, channel, batch]
int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if
// unused.
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
int swLimit; // Active width of sign tensor in bytes.
longlong4 xStride; // Strides of all tensors except signs, same component
// order as shapes.
longlong4 yStride; //
int64_t bStride; //
longlong3 fuStride; //
longlong3 fdStride; //
};
struct filtered_lrelu_act_kernel_params {
void* x; // Input/output, modified in-place.
unsigned char* s; // Sign tensor in/out. NULL if unused.
float gain; // Additional gain factor.
float slope; // Leaky ReLU slope on negative side.
float clamp; // Clamp after nonlinearity.
int4 xShape; // [width, height, channel, batch]
longlong4 xStride; // Input/output tensor strides, same order as in shape.
int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if
// unused.
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
};
//------------------------------------------------------------------------
// CUDA kernel specialization.
struct filtered_lrelu_kernel_spec {
void* setup; // Function for filter kernel setup.
void* exec; // Function for main operation.
int2 tileOut; // Width/height of launch tile.
int numWarps; // Number of warps per thread block, determines launch block
// size.
int xrep; // For processing multiple horizontal tiles per thread block.
int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
};
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T, class index_t, bool signWrite, bool signRead>
filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
const filtered_lrelu_kernel_params& p, int sharedKB);
template <class T, bool signWrite, bool signRead>
void* choose_filtered_lrelu_act_kernel(void);
//------------------------------------------------------------------------
// Helpers.
enum // Filter modes.
{ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
MODE_FUSD = 1, // Full upsampling, separable downsampling.
MODE_SUFD = 2, // Separable upsampling, full downsampling.
MODE_FUFD = 3, // Full upsampling, full downsampling.
};
template <class T>
struct InternalType;
template <>
struct InternalType<double> {
typedef double scalar_t;
typedef double2 vec2_t;
typedef double4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) {
return make_double2(0, 0);
}
__device__ __forceinline__ static vec4_t zero_vec4(void) {
return make_double4(0, 0, 0, 0);
}
__device__ __forceinline__ static double clamp(double x, double c) {
return fmin(fmax(x, -c), c);
}
};
template <>
struct InternalType<float> {
typedef float scalar_t;
typedef float2 vec2_t;
typedef float4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) {
return make_float2(0, 0);
}
__device__ __forceinline__ static vec4_t zero_vec4(void) {
return make_float4(0, 0, 0, 0);
}
__device__ __forceinline__ static float clamp(float x, float c) {
return fminf(fmaxf(x, -c), c);
}
};
template <>
struct InternalType<c10::Half> {
typedef float scalar_t;
typedef float2 vec2_t;
typedef float4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) {
return make_float2(0, 0);
}
__device__ __forceinline__ static vec4_t zero_vec4(void) {
return make_float4(0, 0, 0, 0);
}
__device__ __forceinline__ static float clamp(float x, float c) {
return fminf(fmaxf(x, -c), c);
}
};
#define MIN(A, B) ((A) < (B) ? (A) : (B))
#define MAX(A, B) ((A) > (B) ? (A) : (B))
#define CEIL_DIV(A, B) \
(((B) == 1) \
? (A) \
: ((B) == 2) ? ((int)((A) + 1) >> 1) \
: ((B) == 4) ? ((int)((A) + 3) >> 2) \
: (((A) + ((A) > 0 ? (B)-1 : 0)) / (B)))
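// For example, CEIL_DIV(13, 4) == ((13 + 3) >> 2) == 4, while CEIL_DIV(7, 3)
// takes the general branch: (7 + 2) / 3 == 3.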
// Fast integer division/modulo by the compile-time constant N. Exact for any
// power-of-two N; for other N it uses a multiply-shift approximation that is
// valid only for N <= 256 and i < N * 256, i.e. blocks of at most 256 x 256.
template <int N>
__device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i) {
if ((N & (N - 1)) && N <= 256)
y = (i * ((1 << 24) / N + 1)) >> 24; // Assumes N <= 256, i < N*256.
else
y = i / N;
x = i - y * N;
}
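// For example, fast_div_mod<56>(x, y, 131) takes the multiply-shift path
// (56 is not a power of two) and yields y = 2, x = 19, matching
// 131 = 2 * 56 + 19; for power-of-two N such as MAX_FILTER_SIZE the plain
// division branch is used instead.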
// Type cast stride before reading it.
template <class T>
__device__ __forceinline__ T get_stride(const int64_t& x) {
return *reinterpret_cast<const T*>(&x);
}
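// With index_t == int32_t this reinterprets the low 32 bits of the stored
// 64-bit stride (NVIDIA GPUs are little-endian), avoiding 64-bit index
// arithmetic; the host side only selects 32-bit indexing when all offsets are
// known to fit (see the index64b checks in filtered_lrelu_op below).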
//------------------------------------------------------------------------
// Filters, setup kernel, copying function.
#define MAX_FILTER_SIZE 32
// Combined up/down filter buffers so that transfer can be done with one copy.
__device__ float
g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory,
// written by setup kernel.
__device__ __constant__ float
c_fbuf[2 * MAX_FILTER_SIZE *
MAX_FILTER_SIZE]; // Filters in constant memory, read by main
// kernel.
// Accessors to combined buffers to index up/down filters individually.
#define c_fu (c_fbuf)
#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
#define g_fu (g_fbuf)
#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
// Set up filters into global memory buffer.
static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p) {
for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE;
idx += blockDim.x) {
int x, y;
fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
if (p.fuShape.y > 0)
g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y)
? 0.0f
: p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
else
g_fu[idx] =
(x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
if (p.fdShape.y > 0)
g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y)
? 0.0f
: p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
else
g_fd[idx] =
(x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
}
}
// Host function to copy filters written by setup kernel into constant buffer
// for main kernel.
static cudaError_t copy_filters(cudaStream_t stream) {
void* src = 0;
cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
if (err) return err;
return cudaMemcpyToSymbolAsync(
c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0,
cudaMemcpyDeviceToDevice, stream);
}
//------------------------------------------------------------------------
// Coordinate spaces:
// - Relative to input tensor: inX, inY, tileInX, tileInY
// - Relative to input tile: relInX, relInY, tileInW, tileInH
// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
// - Relative to output tensor: outX, outY, tileOutX, tileOutY
//
// Relationships between coordinate spaces:
// - inX = tileInX + relInX
// - inY = tileInY + relInY
// - relUpX = relInX * up + phaseInX
// - relUpY = relInY * up + phaseInY
// - relUpX = relOutX * down
// - relUpY = relOutY * down
// - outX = tileOutX + relOutX
// - outY = tileOutY + relOutY
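// Illustrative example of the tile/phase bookkeeping inside the kernel
// (assumed values): with up = 2, down = 2, pad0.x = 3 and tileOutX = 56, the
// code computes tmpX = 56 * 2 - 3 = 109, tileInX = CEIL_DIV(109, 2) = 55 and
// phaseInX = 55 * 2 - 109 = 1, i.e. the input sample grid is offset by one
// upsampled sample from the tile's upsampled coordinate origin.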
extern __shared__ char
s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically
// inside the kernel, otherwise use the externally allocated
// shared memory buffer.
template <class T, class index_t, int sharedKB, bool signWrite, bool signRead,
int filterMode, int up, int fuSize, int down, int fdSize,
int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep,
bool enableWriteSkip>
static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
// Check that we don't try to support non-existing filter modes.
static_assert(up == 1 || up == 2 || up == 4,
"only up=1, up=2, up=4 scales supported");
static_assert(down == 1 || down == 2 || down == 4,
"only down=1, down=2, down=4 scales supported");
static_assert(fuSize >= up,
"upsampling filter size must be at least upsampling factor");
static_assert(
fdSize >= down,
"downsampling filter size must be at least downsampling factor");
static_assert(
fuSize % up == 0,
"upsampling filter size must be divisible with upsampling factor");
static_assert(
fdSize % down == 0,
"downsampling filter size must be divisible with downsampling factor");
static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE,
"filter size greater than MAX_FILTER_SIZE");
static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD ||
filterMode == MODE_FUSD)),
"up=1 supported only for 1x1 full filters");
static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD ||
filterMode == MODE_SUFD)),
"down=1 supported only for 1x1 full filters");
static_assert(
!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)),
"full filters not supported for up=4");
static_assert(
!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)),
"full filters not supported for down=4");
// Static definitions.
typedef typename InternalType<T>::scalar_t scalar_t;
typedef typename InternalType<T>::vec2_t vec2_t;
typedef typename InternalType<T>::vec4_t vec4_t;
const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) &
~3; // Upsampled tile width, rounded up to multiple of 4.
const int tileUpH =
tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
const int tileInW =
CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
const int tileInH =
CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
const int tileUpH_up =
CEIL_DIV(tileUpH, up) *
up; // Upsampled tile height rounded up to a multiple of up.
const int tileInH_up =
CEIL_DIV(tileUpH_up + (fuSize - 1),
up); // For allocations only, to avoid shared memory read
// overruns with up=2 and up=4.
// Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
const bool downInline =
(down == 1) && ((up == 1 && filterMode == MODE_FUFD) ||
(up == 2 && filterMode == MODE_SUFD));
// Sizes of logical buffers.
const int szIn = tileInH_up * tileInW;
const int szUpX = tileInH_up * tileUpW;
const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
const int szDownX = tileUpH * tileOutW;
// Sizes for shared memory arrays.
const int s_buf0_size_base =
(filterMode == MODE_SUSD)
? MAX(szIn, szUpXY)
: (filterMode == MODE_FUSD)
? MAX(szIn, szDownX)
: (filterMode == MODE_SUFD)
? MAX(szIn, szUpXY)
: (filterMode == MODE_FUFD) ? szIn : -1;
const int s_buf1_size_base =
(filterMode == MODE_SUSD)
? MAX(szUpX, szDownX)
: (filterMode == MODE_FUSD)
? szUpXY
: (filterMode == MODE_SUFD)
? szUpX
: (filterMode == MODE_FUFD) ? szUpXY : -1;
// Ensure U128 alignment.
const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
// Check at compile time that we don't use too much shared memory.
static_assert(
(s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10),
"shared memory overflow");
// Declare shared memory arrays.
scalar_t* s_buf0;
scalar_t* s_buf1;
if (sharedKB <= 48) {
// Allocate shared memory arrays here.
__shared__ scalar_t
s_buf0_st[(sharedKB > 48)
? (1 << 24)
: (s_buf0_size +
s_buf1_size)]; // Prevent launching if this isn't
// optimized away when unused.
s_buf0 = s_buf0_st;
s_buf1 = s_buf0 + s_buf0_size;
} else {
// Use the dynamically allocated shared memory array.
s_buf0 = (scalar_t*)s_buf_raw;
s_buf1 = s_buf0 + s_buf0_size;
}
// Pointers to the buffers.
scalar_t*
s_tileIn; // Input tile: [relInX * tileInH + relInY]
scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW +
// relUpX]
scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW +
// relUpX]
scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW
// + relOutX]
if (filterMode == MODE_SUSD) {
s_tileIn = s_buf0;
s_tileUpX = s_buf1;
s_tileUpXY = s_buf0;
s_tileDownX = s_buf1;
} else if (filterMode == MODE_FUSD) {
s_tileIn = s_buf0;
s_tileUpXY = s_buf1;
s_tileDownX = s_buf0;
} else if (filterMode == MODE_SUFD) {
s_tileIn = s_buf0;
s_tileUpX = s_buf1;
s_tileUpXY = s_buf0;
} else if (filterMode == MODE_FUFD) {
s_tileIn = s_buf0;
s_tileUpXY = s_buf1;
}
// Allow large grids in z direction via per-launch offset.
int channelIdx = blockIdx.z + p.blockZofs;
int batchIdx = channelIdx / p.yShape.z;
channelIdx -= batchIdx * p.yShape.z;
// Offset to output feature map. In bytes.
index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) +
batchIdx * get_stride<index_t>(p.yStride.w);
// Sign shift amount.
uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
// Inner tile loop.
#pragma unroll 1
for (int tileIdx = 0;
!enableXrep ||
(tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y));
tileIdx++) {
// Locate output tile.
int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
int tileOutX = tileX * tileOutW;
int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
// Locate input tile.
int tmpX = tileOutX * down - p.pad0.x;
int tmpY = tileOutY * down - p.pad0.y;
int tileInX = CEIL_DIV(tmpX, up);
int tileInY = CEIL_DIV(tmpY, up);
const int phaseInX = tileInX * up - tmpX;
const int phaseInY = tileInY * up - tmpY;
// Extra sync if input and output buffers are the same and we are not on
// first tile.
if (enableXrep && tileIdx > 0 &&
(filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) ||
(filterMode == MODE_FUFD && downInline)))
__syncthreads();
// Load input tile & apply bias. Unrolled.
scalar_t b =
(scalar_t) * (const T*)((const char*)p.b +
(channelIdx * get_stride<index_t>(p.bStride)));
index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) +
batchIdx * get_stride<index_t>(p.xStride.w);
int idx = threadIdx.x;
const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
#pragma unroll
for (int loop = 0; loop < loopCountIN; loop++) {
int relInX, relInY;
fast_div_mod<tileInW>(relInX, relInY, idx);
int inX = tileInX + relInX;
int inY = tileInY + relInY;
scalar_t v = 0;
if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
v = (scalar_t) * ((const T*)((const char*)p.x +
(inX * get_stride<index_t>(p.xStride.x) +
inY * get_stride<index_t>(p.xStride.y) +
mapOfsIn))) +
b;
bool skip = (loop == loopCountIN - 1) && (idx >= tileInW * tileInH);
if (!skip) s_tileIn[idx] = v;
idx += threadsPerBlock;
}
if (filterMode == MODE_SUSD ||
filterMode == MODE_SUFD) // Separable upsampling filter.
{
// Horizontal upsampling.
__syncthreads();
if (up == 4) {
for (int idx = threadIdx.x * up; idx < tileUpW * tileInH;
idx += blockDim.x * up) {
int relUpX0, relInY;
fast_div_mod<tileUpW>(relUpX0, relInY, idx);
int relInX0 = relUpX0 / up;
int src0 = relInX0 + tileInW * relInY;
int dst = relInY * tileUpW + relUpX0;
vec4_t v = InternalType<T>::zero_vec4();
scalar_t a = s_tileIn[src0];
if (phaseInX == 0) {
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.y += a * (scalar_t)c_fu[step * up + 3];
v.z += a * (scalar_t)c_fu[step * up + 2];
v.w += a * (scalar_t)c_fu[step * up + 1];
}
} else if (phaseInX == 1) {
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.z += a * (scalar_t)c_fu[step * up + 3];
v.w += a * (scalar_t)c_fu[step * up + 2];
}
} else if (phaseInX == 2) {
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 2];
v.y += a * (scalar_t)c_fu[step * up + 1];
v.z += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.w += a * (scalar_t)c_fu[step * up + 3];
}
} else // (phaseInX == 3)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 3];
v.y += a * (scalar_t)c_fu[step * up + 2];
v.z += a * (scalar_t)c_fu[step * up + 1];
v.w += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
}
}
s_tileUpX[dst + 0] = v.x;
s_tileUpX[dst + 1] = v.y;
s_tileUpX[dst + 2] = v.z;
s_tileUpX[dst + 3] = v.w;
}
} else if (up == 2) {
bool p0 = (phaseInX == 0);
for (int idx = threadIdx.x * up; idx < tileUpW * tileInH;
idx += blockDim.x * up) {
int relUpX0, relInY;
fast_div_mod<tileUpW>(relUpX0, relInY, idx);
int relInX0 = relUpX0 / up;
int src0 = relInX0 + tileInW * relInY;
int dst = relInY * tileUpW + relUpX0;
vec2_t v = InternalType<T>::zero_vec2();
scalar_t a = s_tileIn[src0];
if (p0) // (phaseInX == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.y += a * (scalar_t)c_fu[step * up + 1];
}
} else // (phaseInX == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
}
}
s_tileUpX[dst + 0] = v.x;
s_tileUpX[dst + 1] = v.y;
}
}
// Vertical upsampling & nonlinearity.
__syncthreads();
int groupMask = 15 << ((threadIdx.x & 31) & ~3);
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH
: 0; // Skip already written signs.
int sShapeMaxY =
MIN(p.sShape.y,
tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
if (up == 4) {
minY -= 3; // Adjust according to block height.
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up;
idx += blockDim.x) {
int relUpX, relInY0;
fast_div_mod<tileUpW>(relUpX, relInY0, idx);
int relUpY0 = relInY0 * up;
int src0 = relInY0 * tileUpW + relUpX;
int dst = relUpY0 * tileUpW + relUpX;
vec4_t v = InternalType<T>::zero_vec4();
scalar_t a = s_tileUpX[src0];
if (phaseInY == 0) {
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.y += a * (scalar_t)c_fu[step * up + 3];
v.z += a * (scalar_t)c_fu[step * up + 2];
v.w += a * (scalar_t)c_fu[step * up + 1];
}
} else if (phaseInY == 1) {
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.z += a * (scalar_t)c_fu[step * up + 3];
v.w += a * (scalar_t)c_fu[step * up + 2];
}
} else if (phaseInY == 2) {
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 2];
v.y += a * (scalar_t)c_fu[step * up + 1];
v.z += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.w += a * (scalar_t)c_fu[step * up + 3];
}
} else // (phaseInY == 3)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 3];
v.y += a * (scalar_t)c_fu[step * up + 2];
v.z += a * (scalar_t)c_fu[step * up + 1];
v.w += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
}
}
int x = tileOutX * down + relUpX;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si0 =
signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
index_t si1 = si0 + p.sShape.x;
index_t si2 = si0 + p.sShape.x * 2;
index_t si3 = si0 + p.sShape.x * 3;
v.x *= (scalar_t)((float)up * (float)up * p.gain);
v.y *= (scalar_t)((float)up * (float)up * p.gain);
v.z *= (scalar_t)((float)up * (float)up * p.gain);
v.w *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite) {
if (!enableWriteSkip) {
// Determine and write signs.
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
int sz = __float_as_uint(v.z) >> 31 << 16;
int sw = __float_as_uint(v.w) >> 31 << 24;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (sz) v.z *= p.slope;
if (sw) v.w *= p.slope;
if (fabsf(v.x) > p.clamp) {
sx = 2 << 0;
v.x = InternalType<T>::clamp(v.x, p.clamp);
}
if (fabsf(v.y) > p.clamp) {
sy = 2 << 8;
v.y = InternalType<T>::clamp(v.y, p.clamp);
}
if (fabsf(v.z) > p.clamp) {
sz = 2 << 16;
v.z = InternalType<T>::clamp(v.z, p.clamp);
}
if (fabsf(v.w) > p.clamp) {
sw = 2 << 24;
v.w = InternalType<T>::clamp(v.w, p.clamp);
}
if ((uint32_t)signXb < p.swLimit && signY >= minY) {
// Combine signs.
uint32_t s = sx + sy + sw + sz;
s <<= (signX & 3) << 1;
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) {
p.s[si0] = (unsigned char)(s >> 0);
}
if ((uint32_t)(signY + 1) < sShapeMaxY) {
p.s[si1] = (unsigned char)(s >> 8);
}
if ((uint32_t)(signY + 2) < sShapeMaxY) {
p.s[si2] = (unsigned char)(s >> 16);
}
if ((uint32_t)(signY + 3) < sShapeMaxY) {
p.s[si3] = (unsigned char)(s >> 24);
}
}
} else {
// Determine and write signs.
if ((uint32_t)signXb < p.swLimit && signY >= minY) {
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
int sz = __float_as_uint(v.z) >> 31 << 16;
int sw = __float_as_uint(v.w) >> 31 << 24;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (sz) v.z *= p.slope;
if (sw) v.w *= p.slope;
if (fabsf(v.x) > p.clamp) {
sx = 2 << 0;
v.x = InternalType<T>::clamp(v.x, p.clamp);
}
if (fabsf(v.y) > p.clamp) {
sy = 2 << 8;
v.y = InternalType<T>::clamp(v.y, p.clamp);
}
if (fabsf(v.z) > p.clamp) {
sz = 2 << 16;
v.z = InternalType<T>::clamp(v.z, p.clamp);
}
if (fabsf(v.w) > p.clamp) {
sw = 2 << 24;
v.w = InternalType<T>::clamp(v.w, p.clamp);
}
// Combine signs.
uint32_t s = sx + sy + sw + sz;
s <<= (signX & 3) << 1;
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) {
p.s[si0] = (unsigned char)(s >> 0);
}
if ((uint32_t)(signY + 1) < sShapeMaxY) {
p.s[si1] = (unsigned char)(s >> 8);
}
if ((uint32_t)(signY + 2) < sShapeMaxY) {
p.s[si2] = (unsigned char)(s >> 16);
}
if ((uint32_t)(signY + 3) < sShapeMaxY) {
p.s[si3] = (unsigned char)(s >> 24);
}
} else {
// Just compute the values.
if (v.x < 0.f) v.x *= p.slope;
v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope;
v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope;
v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope;
v.w = InternalType<T>::clamp(v.w, p.clamp);
}
}
} else if (signRead) // Read signs and apply.
{
if ((uint32_t)signXb < p.swLimit) {
int ss = (signX & 3) << 1;
if ((uint32_t)(signY + 0) < p.sShape.y) {
int s = p.s[si0] >> ss;
if (s & 1) v.x *= p.slope;
if (s & 2) v.x = 0.f;
}
if ((uint32_t)(signY + 1) < p.sShape.y) {
int s = p.s[si1] >> ss;
if (s & 1) v.y *= p.slope;
if (s & 2) v.y = 0.f;
}
if ((uint32_t)(signY + 2) < p.sShape.y) {
int s = p.s[si2] >> ss;
if (s & 1) v.z *= p.slope;
if (s & 2) v.z = 0.f;
}
if ((uint32_t)(signY + 3) < p.sShape.y) {
int s = p.s[si3] >> ss;
if (s & 1) v.w *= p.slope;
if (s & 2) v.w = 0.f;
}
}
} else // Forward pass with no sign write.
{
if (v.x < 0.f) v.x *= p.slope;
v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope;
v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope;
v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope;
v.w = InternalType<T>::clamp(v.w, p.clamp);
}
s_tileUpXY[dst + 0 * tileUpW] = v.x;
if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
}
} else if (up == 2) {
minY -= 1; // Adjust according to block height.
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up;
idx += blockDim.x) {
int relUpX, relInY0;
fast_div_mod<tileUpW>(relUpX, relInY0, idx);
int relUpY0 = relInY0 * up;
int src0 = relInY0 * tileUpW + relUpX;
int dst = relUpY0 * tileUpW + relUpX;
vec2_t v = InternalType<T>::zero_vec2();
scalar_t a = s_tileUpX[src0];
if (phaseInY == 0) {
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.y += a * (scalar_t)c_fu[step * up + 1];
}
} else // (phaseInY == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++) {
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
}
}
int x = tileOutX * down + relUpX;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si0 =
signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
index_t si1 = si0 + p.sShape.x;
v.x *= (scalar_t)((float)up * (float)up * p.gain);
v.y *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite) {
if (!enableWriteSkip) {
// Determine and write signs.
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (fabsf(v.x) > p.clamp) {
sx = 2 << 0;
v.x = InternalType<T>::clamp(v.x, p.clamp);
}
if (fabsf(v.y) > p.clamp) {
sy = 2 << 8;
v.y = InternalType<T>::clamp(v.y, p.clamp);
}
if ((uint32_t)signXb < p.swLimit && signY >= minY) {
// Combine signs.
int s = sx + sy;
s <<= signXo;
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) {
p.s[si0] = (unsigned char)(s >> 0);
}
if ((uint32_t)(signY + 1) < sShapeMaxY) {
p.s[si1] = (unsigned char)(s >> 8);
}
}
} else {
// Determine and write signs.
if ((uint32_t)signXb < p.swLimit && signY >= minY) {
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (fabsf(v.x) > p.clamp) {
sx = 2 << 0;
v.x = InternalType<T>::clamp(v.x, p.clamp);
}
if (fabsf(v.y) > p.clamp) {
sy = 2 << 8;
v.y = InternalType<T>::clamp(v.y, p.clamp);
}
// Combine signs.
int s = sx + sy;
s <<= signXo;
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) {
p.s[si0] = (unsigned char)(s >> 0);
}
if ((uint32_t)(signY + 1) < sShapeMaxY) {
p.s[si1] = (unsigned char)(s >> 8);
}
} else {
// Just compute the values.
if (v.x < 0.f) v.x *= p.slope;
v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope;
v.y = InternalType<T>::clamp(v.y, p.clamp);
}
}
} else if (signRead) // Read signs and apply.
{
if ((uint32_t)signXb < p.swLimit) {
if ((uint32_t)(signY + 0) < p.sShape.y) {
int s = p.s[si0] >> signXo;
if (s & 1) v.x *= p.slope;
if (s & 2) v.x = 0.f;
}
if ((uint32_t)(signY + 1) < p.sShape.y) {
int s = p.s[si1] >> signXo;
if (s & 1) v.y *= p.slope;
if (s & 2) v.y = 0.f;
}
}
} else // Forward pass with no sign write.
{
if (v.x < 0.f) v.x *= p.slope;
v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope;
v.y = InternalType<T>::clamp(v.y, p.clamp);
}
if (!downInline) {
// Write into temporary buffer.
s_tileUpXY[dst] = v.x;
if (relUpY0 < tileUpH - 1) s_tileUpXY[dst + tileUpW] = v.y;
} else {
// Write directly into output buffer.
if ((uint32_t)x < p.yShape.x) {
int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
index_t ofs = x * get_stride<index_t>(p.yStride.x) +
y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
if ((uint32_t)y + 0 < p.yShape.y)
*((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
if ((uint32_t)y + 1 < ymax)
*((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) =
(T)(v.y * (scalar_t)c_fd[0]);
}
}
}
}
} else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) {
// Full upsampling filter.
if (up == 2) {
// 2 x 2-wide.
__syncthreads();
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y
: 0; // Skip already written signs.
for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH;
idx += blockDim.x * 4) {
int relUpX0, relUpY0;
fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
int src0 = relInX0 + tileInW * relInY0;
int tap0y = (relInY0 * up + phaseInY - relUpY0);
#define X_LOOP(TAPY, PX) \
for (int sx = 0; sx < fuSize / up; sx++) { \
v.x += a * (scalar_t)c_fu[(sx * up + (((PX)-0) & (up - 1))) + \
(sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
v.z += b * (scalar_t)c_fu[(sx * up + (((PX)-0) & (up - 1))) + \
(sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
if ((PX) == 0) { \
a = b; \
b = s_tileIn[src0 + 2 + sx + sy * tileInW]; \
} \
v.y += a * (scalar_t)c_fu[(sx * up + (((PX)-1) & (up - 1))) + \
(sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
v.w += b * (scalar_t)c_fu[(sx * up + (((PX)-1) & (up - 1))) + \
(sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
if ((PX) == 1) { \
a = b; \
b = s_tileIn[src0 + 2 + sx + sy * tileInW]; \
} \
}
vec4_t v = InternalType<T>::zero_vec4();
if (tap0y == 0 && phaseInX == 0)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) {
scalar_t a = s_tileIn[src0 + sy * tileInW];
scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(0, 0)
}
if (tap0y == 0 && phaseInX == 1)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) {
scalar_t a = s_tileIn[src0 + sy * tileInW];
scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(0, 1)
}
if (tap0y == 1 && phaseInX == 0)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) {
scalar_t a = s_tileIn[src0 + sy * tileInW];
scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(1, 0)
}
if (tap0y == 1 && phaseInX == 1)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) {
scalar_t a = s_tileIn[src0 + sy * tileInW];
scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(1, 1)
}
#undef X_LOOP
int x = tileOutX * down + relUpX0;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si =
signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
v.x *= (scalar_t)((float)up * (float)up * p.gain);
v.y *= (scalar_t)((float)up * (float)up * p.gain);
v.z *= (scalar_t)((float)up * (float)up * p.gain);
v.w *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite) {
if (!enableWriteSkip) {
// Determine and write signs.
int sx = __float_as_uint(v.x) >> 31;
int sy = __float_as_uint(v.y) >> 31;
int sz = __float_as_uint(v.z) >> 31;
int sw = __float_as_uint(v.w) >> 31;
if (sx) v.x *= p.slope;
if (fabsf(v.x) > p.clamp) {
sx = 2;
v.x = InternalType<T>::clamp(v.x, p.clamp);
}
if (sy) v.y *= p.slope;
if (fabsf(v.y) > p.clamp) {
sy = 2;
v.y = InternalType<T>::clamp(v.y, p.clamp);
}
if (sz) v.z *= p.slope;
if (fabsf(v.z) > p.clamp) {
sz = 2;
v.z = InternalType<T>::clamp(v.z, p.clamp);
}
if (sw) v.w *= p.slope;
if (fabsf(v.w) > p.clamp) {
sw = 2;
v.w = InternalType<T>::clamp(v.w, p.clamp);
}
if ((uint32_t)signXb < p.swLimit &&
(uint32_t)signY < p.sShape.y && signY >= minY) {
p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
}
} else {
// Determine and write signs.
if ((uint32_t)signXb < p.swLimit &&
(uint32_t)signY < p.sShape.y && signY >= minY) {
int sx = __float_as_uint(v.x) >> 31;
int sy = __float_as_uint(v.y) >> 31;
int sz = __float_as_uint(v.z) >> 31;
int sw = __float_as_uint(v.w) >> 31;
if (sx) v.x *= p.slope;
if (fabsf(v.x) > p.clamp) {
sx = 2;
v.x = InternalType<T>::clamp(v.x, p.clamp);
}
if (sy) v.y *= p.slope;
if (fabsf(v.y) > p.clamp) {
sy = 2;
v.y = InternalType<T>::clamp(v.y, p.clamp);
}
if (sz) v.z *= p.slope;
if (fabsf(v.z) > p.clamp) {
sz = 2;
v.z = InternalType<T>::clamp(v.z, p.clamp);
}
if (sw) v.w *= p.slope;
if (fabsf(v.w) > p.clamp) {
sw = 2;
v.w = InternalType<T>::clamp(v.w, p.clamp);
}
p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
} else {
// Just compute the values.
if (v.x < 0.f) v.x *= p.slope;
v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope;
v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope;
v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope;
v.w = InternalType<T>::clamp(v.w, p.clamp);
}
}
} else if (signRead) // Read sign and apply.
{
if ((uint32_t)signY < p.sShape.y) {
int s = 0;
if ((uint32_t)signXb < p.swLimit) s = p.s[si];
if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
s >>= (signX & 3) << 1;
if (s & 0x01) v.x *= p.slope;
if (s & 0x02) v.x = 0.f;
if (s & 0x04) v.y *= p.slope;
if (s & 0x08) v.y = 0.f;
if (s & 0x10) v.z *= p.slope;
if (s & 0x20) v.z = 0.f;
if (s & 0x40) v.w *= p.slope;
if (s & 0x80) v.w = 0.f;
}
} else // Forward pass with no sign write.
{
if (v.x < 0.f) v.x *= p.slope;
v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope;
v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope;
v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope;
v.w = InternalType<T>::clamp(v.w, p.clamp);
}
s_tileUpXY[idx + 0] = v.x;
s_tileUpXY[idx + 1] = v.y;
s_tileUpXY[idx + 2] = v.z;
s_tileUpXY[idx + 3] = v.w;
}
} else if (up == 1) {
__syncthreads();
uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH
: 0; // Skip already written signs.
for (int idx = threadIdx.x; idx < tileUpW * tileUpH;
idx += blockDim.x) {
int relUpX0, relUpY0;
fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
int x = tileOutX * down + relUpX0;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si =
signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
v *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite) {
if (!enableWriteSkip) {
// Determine and write sign.
uint32_t s = 0;
uint32_t signXbit = (1u << signXo);
if (v < 0.f) {
s = signXbit;
v *= p.slope;
}
if (fabsf(v) > p.clamp) {
s = signXbit * 2;
v = InternalType<T>::clamp(v, p.clamp);
}
if ((uint32_t)signXb < p.swLimit &&
(uint32_t)signY < p.sShape.y && signY >= minY) {
s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
p.s[si] = s; // Write.
}
} else {
// Determine and write sign.
if ((uint32_t)signXb < p.swLimit &&
(uint32_t)signY < p.sShape.y && signY >= minY) {
uint32_t s = 0;
uint32_t signXbit = (1u << signXo);
if (v < 0.f) {
s = signXbit;
v *= p.slope;
}
if (fabsf(v) > p.clamp) {
s = signXbit * 2;
v = InternalType<T>::clamp(v, p.clamp);
}
s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
p.s[si] = s; // Write.
} else {
// Just compute the value.
if (v < 0.f) v *= p.slope;
v = InternalType<T>::clamp(v, p.clamp);
}
}
} else if (signRead) {
// Read sign and apply if within sign tensor bounds.
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) {
int s = p.s[si];
s >>= signXo;
if (s & 1) v *= p.slope;
if (s & 2) v = 0.f;
}
} else // Forward pass with no sign write.
{
if (v < 0.f) v *= p.slope;
v = InternalType<T>::clamp(v, p.clamp);
}
if (!downInline) // Write into temporary buffer.
s_tileUpXY[idx] = v;
else if ((uint32_t)x < p.yShape.x &&
(uint32_t)y <
p.yShape.y) // Write directly into output buffer
*((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) +
y * get_stride<index_t>(p.yStride.y) +
mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
}
}
}
// Downsampling.
if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) {
// Horizontal downsampling.
__syncthreads();
if (down == 4 && tileOutW % 4 == 0) {
// Calculate 4 pixels at a time.
for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH;
idx += blockDim.x * 4) {
int relOutX0, relUpY;
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
int relUpX0 = relOutX0 * down;
int src0 = relUpY * tileUpW + relUpX0;
vec4_t v = InternalType<T>::zero_vec4();
#pragma unroll
for (int step = 0; step < fdSize; step++) {
v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
}
s_tileDownX[idx + 0] = v.x;
s_tileDownX[idx + 1] = v.y;
s_tileDownX[idx + 2] = v.z;
s_tileDownX[idx + 3] = v.w;
}
} else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) {
// Calculate 2 pixels at a time.
for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH;
idx += blockDim.x * 2) {
int relOutX0, relUpY;
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
int relUpX0 = relOutX0 * down;
int src0 = relUpY * tileUpW + relUpX0;
vec2_t v = InternalType<T>::zero_vec2();
#pragma unroll
for (int step = 0; step < fdSize; step++) {
v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
}
s_tileDownX[idx + 0] = v.x;
s_tileDownX[idx + 1] = v.y;
}
} else {
// Calculate 1 pixel at a time.
for (int idx = threadIdx.x; idx < tileOutW * tileUpH;
idx += blockDim.x) {
int relOutX0, relUpY;
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
int relUpX0 = relOutX0 * down;
int src = relUpY * tileUpW + relUpX0;
scalar_t v = 0.f;
#pragma unroll
for (int step = 0; step < fdSize; step++)
v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
s_tileDownX[idx] = v;
}
}
// Vertical downsampling & store output tile.
__syncthreads();
for (int idx = threadIdx.x; idx < tileOutW * tileOutH;
idx += blockDim.x) {
int relOutX, relOutY0;
fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
int relUpY0 = relOutY0 * down;
int src0 = relUpY0 * tileOutW + relOutX;
scalar_t v = 0;
#pragma unroll
for (int step = 0; step < fdSize; step++)
v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
int outX = tileOutX + relOutX;
int outY = tileOutY + relOutY0;
      if (outX < p.yShape.x && outY < p.yShape.y)
*((T*)((char*)p.y +
(outX * get_stride<index_t>(p.yStride.x) +
outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
}
} else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) {
// Full downsampling filter.
if (down == 2) {
// 2-wide.
__syncthreads();
for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH;
idx += blockDim.x * 2) {
int relOutX0, relOutY0;
fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
int relUpX0 = relOutX0 * down;
int relUpY0 = relOutY0 * down;
int src0 = relUpY0 * tileUpW + relUpX0;
vec2_t v = InternalType<T>::zero_vec2();
#pragma unroll
for (int sy = 0; sy < fdSize; sy++)
#pragma unroll
for (int sx = 0; sx < fdSize; sx++) {
v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] *
(scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] *
(scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
}
int outX = tileOutX + relOutX0;
int outY = tileOutY + relOutY0;
if ((uint32_t)outY < p.yShape.y) {
index_t ofs = outX * get_stride<index_t>(p.yStride.x) +
outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
if (outX + 1 < p.yShape.x)
*((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) =
(T)v.y;
}
}
} else if (down == 1 && !downInline) {
// Thread per pixel.
__syncthreads();
for (int idx = threadIdx.x; idx < tileOutW * tileOutH;
idx += blockDim.x) {
int relOutX0, relOutY0;
fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
int outX = tileOutX + relOutX0;
int outY = tileOutY + relOutY0;
if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
*((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) +
outY * get_stride<index_t>(p.yStride.y) +
mapOfsOut))) = (T)v;
}
}
}
if (!enableXrep) break;
}
}
//------------------------------------------------------------------------
// Compute activation function and signs for upsampled data tensor, modifying
// data tensor in-place. Used for accelerating the generic variant. Sign tensor
// is known to be contiguous, and p.x and p.s have the same z, w dimensions.
// 64-bit indexing is always used.
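// Sign encoding used throughout this file: 2 bits per element, where bit 0
// records that the value was negative (leaky slope applied) and bit 1 that it
// was clamped; on the read path bit 0 re-applies the slope and bit 1 zeroes
// the value. Four elements are packed per byte, and below 16 lanes coalesce
// their 2-bit codes into a single 32-bit store.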
template <class T, bool signWrite, bool signRead>
static __global__ void filtered_lrelu_act_kernel(
filtered_lrelu_act_kernel_params p) {
typedef typename InternalType<T>::scalar_t scalar_t;
// Indexing.
int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
int32_t qmax =
p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
// Loop to accommodate oversized tensors.
for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y) {
// Extract z and w (channel, minibatch index).
int32_t w = q / p.xShape.z;
int32_t z = q - w * p.xShape.z;
// Choose behavior based on sign read/write mode.
if (signWrite) {
// Process value if in p.x.
uint32_t s = 0;
if (x < p.xShape.x && y < p.xShape.y) {
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z +
w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
// Gain, LReLU, clamp.
v *= p.gain;
if (v < 0.f) {
v *= p.slope;
s = 1; // Sign.
}
if (fabsf(v) > p.clamp) {
v = InternalType<T>::clamp(v, p.clamp);
s = 2; // Clamp.
}
*pv = (T)v; // Write value.
}
// Coalesce into threads 0 and 16 of warp.
uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
s <<= ((threadIdx.x & 15) << 1); // Shift into place.
s |= __shfl_xor_sync(m, s, 1); // Distribute.
s |= __shfl_xor_sync(m, s, 2);
s |= __shfl_xor_sync(m, s, 4);
s |= __shfl_xor_sync(m, s, 8);
// Write signs if leader and in p.s.
if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
{
uint64_t is =
x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
((uint32_t*)p.s)[is >> 4] = s;
}
} else if (signRead) {
// Process value if in p.x.
if (x < p.xShape.x) // y is always in.
{
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z +
w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
v *= p.gain;
// Apply sign buffer offset.
uint32_t sx = x + p.sOfs.x;
uint32_t sy = y + p.sOfs.y;
// Read and apply signs if we land inside valid region of sign buffer.
if (sx < p.sShape.x && sy < p.sShape.y) {
uint64_t is =
(sx >> 2) + (p.sShape.x >> 2) *
(sy + (uint64_t)p.sShape.y * q); // Contiguous.
unsigned char s = p.s[is];
s >>= (sx & 3) << 1; // Shift into place.
if (s & 1) // Sign?
v *= p.slope;
if (s & 2) // Clamp?
v = 0.f;
}
*pv = (T)v; // Write value.
}
} else {
// Forward pass with no sign write. Process value if in p.x.
if (x < p.xShape.x) // y is always in.
{
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z +
w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
v *= p.gain;
if (v < 0.f) v *= p.slope;
if (fabsf(v) > p.clamp) v = InternalType<T>::clamp(v, p.clamp);
*pv = (T)v; // Write value.
}
}
}
}
template <class T, bool signWrite, bool signRead>
void* choose_filtered_lrelu_act_kernel(void) {
return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
}
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T, class index_t, bool signWrite, bool signRead>
filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
const filtered_lrelu_kernel_params& p, int sharedKB) {
filtered_lrelu_kernel_spec s = {0};
// Return the first matching kernel.
#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
if (sharedKB >= SH) \
if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || \
(p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || \
(p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && \
p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) { \
static_assert((D * TW % 4) == 0, \
"down * tileWidth must be divisible by 4"); \
static_assert( \
FU % U == 0, \
"upscaling filter size must be multiple of upscaling factor"); \
static_assert(FD % D == 0, \
"downscaling filter size must be multiple of " \
"downscaling factor"); \
s.setup = (void*)setup_filters_kernel; \
s.exec = (void*) \
filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, \
U, FU, D, FD, TW, TH, W * 32, !!XR, !!WS>; \
s.tileOut = make_int2(TW, TH); \
s.numWarps = W; \
s.xrep = XR; \
s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
return s; \
}
// Launch parameters for various kernel specializations.
// Small filters must be listed before large filters, otherwise the kernel for
// the larger filter will always match first. Kernels that use more shared
// memory must be listed before those that use less, for the same reason.
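// For example, a separable 8-tap up filter (fuShape.x == 8) also satisfies the
// fu <= 12 and fu <= 16 conditions, so listing the 12- or 16-tap entries first
// would shadow the kernel tuned for 8 taps.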
CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 1, 1, /*mode*/ MODE_FUFD,
/*tw,th,warps,xrep,wskip*/ 64, 178, 32, 0, 0) // 1t-upf1-downf1
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD,
/*tw,th,warps,xrep,wskip*/ 152, 95, 16, 0, 0) // 4t-ups2-downf1
CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 8, /*mode*/ MODE_FUSD,
/*tw,th,warps,xrep,wskip*/ 56, 22, 16, 0, 0) // 4t-upf1-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_SUSD,
/*tw,th,warps,xrep,wskip*/ 56, 29, 16, 11, 0) // 4t-ups2-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_FUSD,
/*tw,th,warps,xrep,wskip*/ 60, 28, 16, 0, 0) // 4t-upf2-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_SUFD,
/*tw,th,warps,xrep,wskip*/ 56, 28, 16, 0, 0) // 4t-ups2-downf2
CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 16, /*down,fd*/ 2, 8, /*mode*/ MODE_SUSD,
/*tw,th,warps,xrep,wskip*/ 56, 31, 16, 11, 0) // 4t-ups4-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 16, /*down,fd*/ 2, 8, /*mode*/ MODE_SUFD,
/*tw,th,warps,xrep,wskip*/ 56, 36, 16, 0, 0) // 4t-ups4-downf2
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 4, 16, /*mode*/ MODE_SUSD,
/*tw,th,warps,xrep,wskip*/ 16, 22, 16, 12, 0) // 4t-ups2-downs4
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 4, 16, /*mode*/ MODE_FUSD,
/*tw,th,warps,xrep,wskip*/ 29, 15, 16, 0, 0) // 4t-upf2-downs4
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD,
/*tw,th,warps,xrep,wskip*/ 96, 150, 28, 0, 0) // 6t-ups2-downf1
CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 12, /*mode*/ MODE_FUSD,
/*tw,th,warps,xrep,wskip*/ 32, 35, 24, 0, 0) // 6t-upf1-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_SUSD,
/*tw,th,warps,xrep,wskip*/ 32, 46, 16, 10, 0) // 6t-ups2-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_FUSD,
/*tw,th,warps,xrep,wskip*/ 58, 28, 24, 8, 0) // 6t-upf2-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_SUFD,
/*tw,th,warps,xrep,wskip*/ 52, 28, 16, 0, 0) // 6t-ups2-downf2
CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 24, /*down,fd*/ 2, 12, /*mode*/ MODE_SUSD,
/*tw,th,warps,xrep,wskip*/ 32, 51, 16, 5, 0) // 6t-ups4-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 24, /*down,fd*/ 2, 12, /*mode*/ MODE_SUFD,
/*tw,th,warps,xrep,wskip*/ 32, 56, 16, 6, 0) // 6t-ups4-downf2
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_SUSD,
/*tw,th,warps,xrep,wskip*/ 16, 18, 16, 12, 0) // 6t-ups2-downs4
CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_FUSD,
/*tw,th,warps,xrep,wskip*/ 27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_FUSD,
/*tw,th,warps,xrep,wskip*/ 27, 13, 24, 0, 0) // 6t-upf2-downs4
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD,
/*tw,th,warps,xrep,wskip*/ 148, 89, 24, 0, 0) // 8t-ups2-downf1
CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 16, /*mode*/ MODE_FUSD,
/*tw,th,warps,xrep,wskip*/ 32, 31, 16, 5, 0) // 8t-upf1-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_SUSD,
/*tw,th,warps,xrep,wskip*/ 32, 41, 16, 9, 0) // 8t-ups2-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_FUSD,
/*tw,th,warps,xrep,wskip*/ 56, 26, 24, 0, 0) // 8t-upf2-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_SUFD,
/*tw,th,warps,xrep,wskip*/ 32, 40, 16, 0, 0) // 8t-ups2-downf2
CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 32, /*down,fd*/ 2, 16, /*mode*/ MODE_SUSD,
/*tw,th,warps,xrep,wskip*/ 32, 46, 24, 5, 0) // 8t-ups4-downs2
CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 32, /*down,fd*/ 2, 16, /*mode*/ MODE_SUFD,
/*tw,th,warps,xrep,wskip*/ 32, 50, 16, 0, 0) // 8t-ups4-downf2
CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_SUSD,
/*tw,th,warps,xrep,wskip*/ 24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_SUSD,
/*tw,th,warps,xrep,wskip*/ 16, 13, 16, 10, 1) // 8t-ups2-downs4
CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_FUSD,
/*tw,th,warps,xrep,wskip*/ 25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_FUSD,
/*tw,th,warps,xrep,wskip*/ 25, 10, 24, 0, 0) // 8t-upf2-downs4
#undef CASE
return s; // No kernel found.
}
//------------------------------------------------------------------------
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns) {
// Set CUDA device.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
// Validate arguments.
TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() &&
b.device() == x.device(),
"all input tensors must reside on the same device");
TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat,
"fu and fd must be float32");
TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat,
"x and b must be float16 or float32");
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX &&
x.size(3) <= INT_MAX,
"x is too large");
TORCH_CHECK(x.numel() > 0, "x is empty");
TORCH_CHECK(
(fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2),
"fu and fd must be rank 1 or 2");
TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX,
"fu is too large");
TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX,
"fd is too large");
TORCH_CHECK(fu.numel() > 0, "fu is empty");
TORCH_CHECK(fd.numel() > 0, "fd is empty");
TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1),
"b must be a vector with the same number of channels as x");
TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
// Figure out how much shared memory is available on the device.
int maxSharedBytes = 0;
AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
x.device().index()));
int sharedKB = maxSharedBytes >> 10;
// Populate enough launch parameters to check if a CUDA kernel exists.
filtered_lrelu_kernel_params p;
p.up = up;
p.down = down;
p.fuShape =
make_int2((int)fu.size(-1),
fu.dim() == 2 ? (int)fu.size(0)
: 0); // shape [n, 0] indicates separable filter.
p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
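  // For example, a separable 12-tap upsampling filter (fu.dim() == 1) yields
  // fuShape = (12, 0), while a full 8x8 filter yields fuShape = (8, 8); the
  // kernel selection distinguishes the two by .y == 0 versus .y > 0.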
filtered_lrelu_kernel_spec test_spec =
choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
if (!test_spec.exec) {
// No kernel found - return empty tensors and indicate missing kernel with
// return code of -1.
return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
}
// Input/output element size.
int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
// Input sizes.
int64_t xw = (int)x.size(3);
int64_t xh = (int)x.size(2);
int64_t fut_w = (int)fu.size(-1) - 1;
int64_t fut_h = (int)fu.size(0) - 1;
int64_t fdt_w = (int)fd.size(-1) - 1;
int64_t fdt_h = (int)fd.size(0) - 1;
// Logical size of upsampled buffer.
int64_t cw = xw * up + (px0 + px1) - fut_w;
int64_t ch = xh * up + (py0 + py1) - fut_h;
TORCH_CHECK(
cw > fdt_w && ch > fdt_h,
"upsampled buffer must be at least the size of downsampling filter");
TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
// Compute output size and allocate.
int64_t yw = (cw - fdt_w + (down - 1)) / down;
int64_t yh = (ch - fdt_h + (down - 1)) / down;
TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
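  // Illustrative example (assumed values): xw = 16, up = 2, px0 = px1 = 3 and
  // a 12-tap fu give cw = 16 * 2 + 6 - 11 = 27; with down = 2 and a 12-tap fd,
  // yw = (27 - 11 + 1) / 2 = 8.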
torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(),
x.suggest_memory_format());
// Allocate sign tensor.
torch::Tensor so;
torch::Tensor s = si;
bool readSigns = !!s.numel();
int64_t sw_active = 0; // Active width of sign tensor.
if (writeSigns) {
sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements,
// rounded up to multiple of 16.
TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2},
x.options().dtype(torch::kUInt8),
at::MemoryFormat::Contiguous);
} else if (readSigns)
sw_active = s.size(3) << 2;
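  // Illustrative example (assumed values yw = 8, down = 2, fdt_w = 11):
  // sw_active = 8 * 2 - 1 + 11 = 26 sign elements per row, padded to sw = 32
  // and stored as 32 >> 2 = 8 bytes, i.e. four 2-bit sign codes per byte.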
// Validate sign tensor if in use.
if (readSigns || writeSigns) {
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
TORCH_CHECK(s.device() == x.device(),
"signs must reside on the same device as x");
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1),
"signs must have same batch & channels as x");
TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX,
"signs is too large");
}
// Populate rest of CUDA kernel parameters.
p.x = x.data_ptr();
p.y = y.data_ptr();
p.b = b.data_ptr();
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
p.fu = fu.data_ptr<float>();
p.fd = fd.data_ptr<float>();
p.pad0 = make_int2(px0, py0);
p.gain = gain;
p.slope = slope;
p.clamp = clamp;
p.flip = (flip_filters) ? 1 : 0;
p.xShape =
make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.yShape =
make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
p.sShape = (readSigns || writeSigns)
? make_int2((int)s.size(3), (int)s.size(2))
: make_int2(0, 0); // Width is in bytes. Contiguous.
p.sOfs = make_int2(sx, sy);
p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
// x, y, b strides are in bytes.
p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2),
sz * x.stride(1), sz * x.stride(0));
p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2),
sz * y.stride(1), sz * y.stride(0));
p.bStride = sz * b.stride(0);
// fu, fd strides are in elements.
p.fuStride =
make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
p.fdStride =
make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
// Determine if indices don't fit in int32. Support negative strides although
// Torch currently never produces those.
bool index64b = false;
if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
if (std::min(x.size(0) * p.xStride.w, 0ll) +
std::min(x.size(1) * p.xStride.z, 0ll) +
std::min(x.size(2) * p.xStride.y, 0ll) +
std::min(x.size(3) * p.xStride.x, 0ll) <
-INT_MAX)
index64b = true;
if (std::max(x.size(0) * p.xStride.w, 0ll) +
std::max(x.size(1) * p.xStride.z, 0ll) +
std::max(x.size(2) * p.xStride.y, 0ll) +
std::max(x.size(3) * p.xStride.x, 0ll) >
INT_MAX)
index64b = true;
if (std::min(y.size(0) * p.yStride.w, 0ll) +
std::min(y.size(1) * p.yStride.z, 0ll) +
std::min(y.size(2) * p.yStride.y, 0ll) +
std::min(y.size(3) * p.yStride.x, 0ll) <
-INT_MAX)
index64b = true;
if (std::max(y.size(0) * p.yStride.w, 0ll) +
std::max(y.size(1) * p.yStride.z, 0ll) +
std::max(y.size(2) * p.yStride.y, 0ll) +
std::max(y.size(3) * p.yStride.x, 0ll) >
INT_MAX)
index64b = true;
if (s.numel() > INT_MAX) index64b = true;
// Choose CUDA kernel.
filtered_lrelu_kernel_spec spec = {0};
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x.scalar_type(), "filtered_lrelu_cuda", [&] {
if constexpr (sizeof(scalar_t) <=
4) // Exclude doubles. constexpr prevents template
// instantiation.
{
// Choose kernel based on index type, datatype and sign read/write
// modes.
if (!index64b && writeSigns && !readSigns)
spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(
p, sharedKB);
else if (!index64b && !writeSigns && readSigns)
spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true>(
p, sharedKB);
else if (!index64b && !writeSigns && !readSigns)
spec =
choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(
p, sharedKB);
else if (index64b && writeSigns && !readSigns)
spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(
p, sharedKB);
else if (index64b && !writeSigns && readSigns)
spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true>(
p, sharedKB);
else if (index64b && !writeSigns && !readSigns)
spec =
choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(
p, sharedKB);
}
});
  TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found");
  // This should not happen because we verified above that a matching kernel
  // exists.
// Launch CUDA kernel.
void* args[] = {&p};
int bx = spec.numWarps * 32;
int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
int gz = p.yShape.z * p.yShape.w;
// Repeat multiple horizontal tiles in a CTA?
if (spec.xrep) {
p.tilesXrep = spec.xrep;
p.tilesXdim = gx;
gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
std::swap(gx, gy);
} else {
p.tilesXrep = 0;
p.tilesXdim = 0;
}
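  // For example (assumed sizes), yShape.x = 512 with tileOut.x = 56 gives
  // gx = 10 column tiles; with spec.xrep = 11 this becomes tilesXdim = 10,
  // gx = 1, and gx/gy are swapped so that each CTA walks up to xrep horizontal
  // tiles in its inner loop.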
// Launch filter setup kernel.
AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0,
at::cuda::getCurrentCUDAStream()));
  // Copy filters to constant memory. copy_filters does not depend on the sign
  // read/write mode, so a single call covers every dispatch combination.
  AT_CUDA_CHECK(copy_filters(at::cuda::getCurrentCUDAStream()));
// Set cache and shared memory configurations for main kernel.
AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
AT_CUDA_CHECK(cudaFuncSetAttribute(
spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize,
spec.dynamicSharedKB << 10));
AT_CUDA_CHECK(
cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
// Launch main kernel.
const int maxSubGz = 65535; // CUDA maximum for block z dimension.
for (int zofs = 0; zofs < gz;
zofs += maxSubGz) // Do multiple launches if gz is too big.
{
p.blockZofs = zofs;
int subGz = std::min(maxSubGz, gz - zofs);
AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
spec.dynamicSharedKB << 10,
at::cuda::getCurrentCUDAStream()));
}
// Done.
return std::make_tuple(y, so, 0);
}
//------------------------------------------------------------------------
torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
int sy, float gain, float slope,
float clamp, bool writeSigns) {
// Set CUDA device.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
// Validate arguments.
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX &&
x.size(3) <= INT_MAX,
"x is too large");
TORCH_CHECK(x.numel() > 0, "x is empty");
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat ||
x.dtype() == torch::kDouble,
"x must be float16, float32 or float64");
// Output signs if we don't have sign input.
torch::Tensor so;
torch::Tensor s = si;
bool readSigns = !!s.numel();
if (writeSigns) {
int64_t sw = x.size(3);
sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2},
x.options().dtype(torch::kUInt8),
at::MemoryFormat::Contiguous);
}
// Validate sign tensor if in use.
if (readSigns || writeSigns) {
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
TORCH_CHECK(s.device() == x.device(),
"signs must reside on the same device as x");
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1),
"signs must have same batch & channels as x");
TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX,
"signs tensor is too large");
}
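  // Layout note with a worked example (derived from the checks above): signs
  // are stored with 2 bits per pixel, i.e. 4 pixels per uint8 byte, and the
  // packed width is rounded up to a multiple of 16 pixels for coalescing. For
  // an input of width 30 this gives sw = 32 pixels and a sign row of
  // 32 >> 2 = 8 bytes; p.sShape.x below reports the width back in pixels
  // (s.size(3) << 2).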
// Initialize CUDA kernel parameters.
filtered_lrelu_act_kernel_params p;
p.x = x.data_ptr();
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
p.gain = gain;
p.slope = slope;
p.clamp = clamp;
p.xShape =
make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.xStride =
make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
p.sShape = (readSigns || writeSigns)
? make_int2((int)s.size(3) << 2, (int)s.size(2))
: make_int2(0, 0); // Width is in elements. Contiguous.
p.sOfs = make_int2(sx, sy);
// Choose CUDA kernel.
void* func = 0;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x.scalar_type(), "filtered_lrelu_act_cuda", [&] {
if (writeSigns)
func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
else if (readSigns)
func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
else
func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
});
TORCH_CHECK(func, "internal error - CUDA kernel not found");
// Launch CUDA kernel.
void* args[] = {&p};
int bx = 128; // 4 warps per block.
// Logical size of launch = writeSigns ? p.s : p.x
uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
uint32_t gz =
p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
gx = (gx - 1) / bx + 1;
// Make sure grid y and z dimensions are within CUDA launch limits. Kernel
// loops internally to do the rest.
const uint32_t gmax = 65535;
gy = std::min(gy, gmax);
gz = std::min(gz, gmax);
// Launch.
AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
at::cuda::getCurrentCUDAStream()));
return so;
}
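//------------------------------------------------------------------------
// Usage sketch (mirrors the generic fallback in mmcv/ops/filtered_lrelu.py):
// after upsampling with upfirdn2d, the Python side calls
//   so = ext_module.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp,
//                                       write_signs)
// which dispatches to filtered_lrelu_act_op above, modifies y in place, and
// returns the bit-packed sign tensor when write_signs is set.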
// Modified from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include <torch/extension.h>
#include "pytorch_cuda_helper.hpp"
struct upfirdn2d_kernel_params {
const void *x;
const float *f;
void *y;
int2 up;
int2 down;
int2 pad0;
int flip;
float gain;
int4 inSize; // [width, height, channel, batch]
int4 inStride;
int2 filterSize; // [width, height]
int2 filterStride;
int4 outSize; // [width, height, channel, batch]
int4 outStride;
int sizeMinor;
int sizeMajor;
int loopMinor;
int loopMajor;
int loopX;
int launchMinor;
int launchMajor;
};
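//------------------------------------------------------------------------
// Illustrative note (field order inferred from upfirdn2d_op below): for an
// NCHW input tensor x the host code fills
//   p.inSize   = make_int4(W, H, C, N);
//   p.inStride = make_int4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
// so the .x/.y/.z/.w components always correspond to width/height/channel/
// batch, for both contiguous and channels-last inputs.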
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
//------------------------------------------------------------------------
// CUDA kernel specialization.
#include <ATen/cuda/CUDAApplyUtils.cuh>
struct upfirdn2d_kernel_spec {
void *kernel;
int tileOutW;
int tileOutH;
int loopMinor;
int loopX;
};
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T>
upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params &p);
//------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//------------------------------------------------------------------------
// Helpers.
template <class T>
struct InternalType;
template <>
struct InternalType<double> {
typedef double scalar_t;
};
template <>
struct InternalType<float> {
typedef float scalar_t;
};
template <>
struct InternalType<c10::Half> {
typedef float scalar_t;
};
static __device__ __forceinline__ int floor_div(int a, int b) {
int t = 1 - a / b;
return (a + t * b) / b - t;
}
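// Worked example: floor_div(-3, 2) evaluates to -2 (t = 1 - (-3 / 2) = 2, and
// (-3 + 2 * 2) / 2 - 2 = -2), i.e. true floor division, whereas plain integer
// division would truncate toward zero and give -1.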
//------------------------------------------------------------------------
// Generic CUDA implementation for large filters.
template <class T>
static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) {
  typedef typename InternalType<T>::scalar_t scalar_t;
  // Calculate thread index.
  int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
  int outY = minorBase / p.launchMinor;
  minorBase -= outY * p.launchMinor;
  int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
  int majorBase = blockIdx.z * p.loopMajor;
  if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
    return;
  // Setup Y receptive field.
  int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
  int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
  int h =
      min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
  int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
  if (p.flip) filterY = p.filterSize.y - 1 - filterY;
  // Loop over major, minor, and X.
  for (int majorIdx = 0, major = majorBase;
       majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    for (int minorIdx = 0, minor = minorBase;
         minorIdx < p.loopMinor & minor < p.sizeMinor;
         minorIdx++, minor += p.launchMinor) {
      int nc = major * p.sizeMinor + minor;
      int n = nc / p.inSize.z;
      int c = nc - n * p.inSize.z;
      for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x;
           loopX++, outX += blockDim.y) {
        // Setup X receptive field.
        int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
        int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
        int w =
            min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) -
            inX;
        int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
        if (p.flip) filterX = p.filterSize.x - 1 - filterX;
        // Initialize pointers.
        const T *xp =
            &((const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y +
                              c * p.inStride.z + n * p.inStride.w];
        const float *fp =
            &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
        int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
        int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
        // Inner loop.
        scalar_t v = 0;
        for (int y = 0; y < h; y++) {
          for (int x = 0; x < w; x++) {
            v += (scalar_t)(*xp) * (scalar_t)(*fp);
            xp += p.inStride.x;
            fp += filterStepX;
          }
          xp += p.inStride.y - w * p.inStride.x;
          fp += filterStepY - w * filterStepX;
        }
        // Store result.
        v *= p.gain;
        ((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y +
                   c * p.outStride.z + n * p.outStride.w] = (T)v;
      }
    }
}
//------------------------------------------------------------------------
// Specialized CUDA implementation for small filters.
template <class T, int upx, int upy, int downx, int downy, int filterW,
          int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) {
  typedef typename InternalType<T>::scalar_t scalar_t;
  const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
  const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
  __shared__ volatile scalar_t sf[filterH][filterW];
  __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
  // Calculate tile index.
  int minorBase = blockIdx.x;
  int tileOutY = minorBase / p.launchMinor;
  minorBase -= tileOutY * p.launchMinor;
  minorBase *= loopMinor;
  tileOutY *= tileOutH;
  int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
  int majorBase = blockIdx.z * p.loopMajor;
  if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y |
      majorBase >= p.sizeMajor)
    return;
  // Load filter (flipped).
  for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW;
       tapIdx += blockDim.x) {
    int fy = tapIdx / filterW;
    int fx = tapIdx - fy * filterW;
    scalar_t v = 0;
    if (fx < p.filterSize.x & fy < p.filterSize.y) {
      int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
      int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
      v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
    }
    sf[fy][fx] = v;
  }
  // Loop over major and X.
  for (int majorIdx = 0, major = majorBase;
       majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) {
    int baseNC = major * p.sizeMinor + minorBase;
    int n = baseNC / p.inSize.z;
    int baseC = baseNC - n * p.inSize.z;
    for (int loopX = 0, tileOutX = tileOutXBase;
         loopX < p.loopX & tileOutX < p.outSize.x;
         loopX++, tileOutX += tileOutW) {
      // Load input pixels.
      int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
      int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
      int tileInX = floor_div(tileMidX, upx);
      int tileInY = floor_div(tileMidY, upy);
      __syncthreads();
      for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor;
           inIdx += blockDim.x) {
        int relC = inIdx;
        int relInX = relC / loopMinor;
        int relInY = relInX / tileInW;
        relC -= relInX * loopMinor;
        relInX -= relInY * tileInW;
        int c = baseC + relC;
        int inX = tileInX + relInX;
        int inY = tileInY + relInY;
        scalar_t v = 0;
        if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y &
            c < p.inSize.z)
          v = (scalar_t)(
              (const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y +
                              c * p.inStride.z + n * p.inStride.w];
        sx[relInY][relInX][relC] = v;
      }
      // Loop over output pixels.
      __syncthreads();
      for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor;
           outIdx += blockDim.x) {
        int relC = outIdx;
        int relOutX = relC / loopMinor;
        int relOutY = relOutX / tileOutW;
        relC -= relOutX * loopMinor;
        relOutX -= relOutY * tileOutW;
        int c = baseC + relC;
        int outX = tileOutX + relOutX;
        int outY = tileOutY + relOutY;
        // Setup receptive field.
        int midX = tileMidX + relOutX * downx;
        int midY = tileMidY + relOutY * downy;
        int inX = floor_div(midX, upx);
        int inY = floor_div(midY, upy);
        int relInX = inX - tileInX;
        int relInY = inY - tileInY;
        int filterX = (inX + 1) * upx - midX - 1;  // flipped
        int filterY = (inY + 1) * upy - midY - 1;  // flipped
        // Inner loop.
        if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) {
          scalar_t v = 0;
#pragma unroll
          for (int y = 0; y < filterH / upy; y++)
#pragma unroll
            for (int x = 0; x < filterW / upx; x++)
              v += sx[relInY + y][relInX + x][relC] *
                   sf[filterY + y * upy][filterX + x * upx];
          v *= p.gain;
          ((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y +
                     c * p.outStride.z + n * p.outStride.w] = (T)v;
        }
      }
    }
  }
}
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T>
upfirdn2d_kernel_spec choose_upfirdn2d_kernel(
const upfirdn2d_kernel_params &p) {
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
upfirdn2d_kernel_spec spec = {(void *)upfirdn2d_kernel_large<T>, -1, -1, 1,
4}; // contiguous
if (s == 1)
spec = {(void *)upfirdn2d_kernel_large<T>, -1, -1, 4, 1}; // channels_last
// No up/downsampling.
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 24, 64, 32, 1>,
64, 32, 1, 1};
if (s != 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 16, 64, 32, 1>,
64, 32, 1, 1};
if (s != 1 && fx <= 7 && fy <= 7)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 7, 7, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 6, 6, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 5 && fy <= 5)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 5, 5, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 4, 4, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 3 && fy <= 3)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 3, 3, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 8, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 24, 32, 32, 1>,
32, 32, 1, 1};
if (s != 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 16, 32, 32, 1>,
32, 32, 1, 1};
if (s != 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 8, 32, 32, 1>,
32, 32, 1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 24, 32, 32, 1>,
32, 32, 1, 1};
if (s == 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 16, 32, 32, 1>,
32, 32, 1, 1};
if (s == 1 && fx <= 7 && fy <= 7)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 7, 7, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 6, 6, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 5 && fy <= 5)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 5, 5, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 4, 4, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 3 && fy <= 3)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 3, 3, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 8, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 24, 1, 128, 16>,
1, 128, 16, 1};
if (s == 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 16, 1, 128, 16>,
1, 128, 16, 1};
if (s == 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 8, 1, 128, 16>,
1, 128, 16, 1};
}
// 2x upsampling.
if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 24, 24, 64, 32, 1>,
64, 32, 1, 1};
if (s != 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 16, 16, 64, 32, 1>,
64, 32, 1, 1};
if (s != 1 && fx <= 8 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 8, 8, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 6, 6, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 4, 4, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 2 && fy <= 2)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 2, 2, 64, 16, 1>,
64, 16, 1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 24, 24, 32, 32, 1>,
32, 32, 1, 1};
if (s == 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 16, 16, 32, 32, 1>,
32, 32, 1, 1};
if (s == 1 && fx <= 8 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 8, 8, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 6, 6, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 4, 4, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 2 && fy <= 2)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 2, 2, 16, 16, 8>,
16, 16, 8, 1};
}
if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 24, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 16, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 8, 1, 128, 8, 1>,
128, 8, 1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 24, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 16, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 8, 1, 128, 1, 16>,
128, 1, 16, 1};
}
if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 24, 32, 32, 1>,
32, 32, 1, 1};
if (s != 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 16, 32, 32, 1>,
32, 32, 1, 1};
if (s != 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 8, 32, 32, 1>,
32, 32, 1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 24, 1, 128, 16>,
1, 128, 16, 1};
if (s == 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 16, 1, 128, 16>,
1, 128, 16, 1};
if (s == 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 8, 1, 128, 16>,
1, 128, 16, 1};
}
// 2x downsampling.
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) {
// contiguous
if (s != 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 24, 24, 32, 16, 1>,
32, 16, 1, 1};
if (s != 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 16, 16, 32, 16, 1>,
32, 16, 1, 1};
if (s != 1 && fx <= 8 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 8, 8, 32, 8, 1>, 32,
8, 1, 1};
if (s != 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 6, 6, 32, 8, 1>, 32,
8, 1, 1};
if (s != 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 4, 4, 32, 8, 1>, 32,
8, 1, 1};
if (s != 1 && fx <= 2 && fy <= 2)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 2, 2, 32, 8, 1>, 32,
8, 1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 24, 24, 16, 16, 1>,
16, 16, 1, 1};
if (s == 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 16, 16, 16, 16, 1>,
16, 16, 1, 1};
if (s == 1 && fx <= 8 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 8, 8, 8, 8, 8>, 8,
8, 8, 1};
if (s == 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 6, 6, 8, 8, 8>, 8,
8, 8, 1};
if (s == 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 4, 4, 8, 8, 8>, 8,
8, 8, 1};
if (s == 1 && fx <= 2 && fy <= 2)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 2, 2, 8, 8, 8>, 8,
8, 8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 24, 1, 64, 8, 1>,
64, 8, 1, 1};
if (s != 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 16, 1, 64, 8, 1>,
64, 8, 1, 1};
if (s != 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 8, 1, 64, 8, 1>, 64,
8, 1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 24, 1, 64, 1, 8>,
64, 1, 8, 1};
if (s == 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 16, 1, 64, 1, 8>,
64, 1, 8, 1};
if (s == 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 8, 1, 64, 1, 8>, 64,
1, 8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) {
// contiguous
if (s != 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 24, 32, 16, 1>,
32, 16, 1, 1};
if (s != 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 16, 32, 16, 1>,
32, 16, 1, 1};
if (s != 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 8, 32, 16, 1>,
32, 16, 1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 24, 1, 64, 8>, 1,
64, 8, 1};
if (s == 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 16, 1, 64, 8>, 1,
64, 8, 1};
if (s == 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 8, 1, 64, 8>, 1,
64, 8, 1};
}
// 4x upsampling.
if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 48 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 48, 48, 64, 32, 1>,
64, 32, 1, 1};
if (s != 1 && fx <= 32 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 32, 32, 64, 32, 1>,
64, 32, 1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 48, 48, 32, 32, 1>,
32, 32, 1, 1};
if (s == 1 && fx <= 32 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 32, 32, 32, 32, 1>,
32, 32, 1, 1};
}
if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 48 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 48, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 32 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 32, 1, 128, 8, 1>,
128, 8, 1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 48, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 32 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 32, 1, 128, 1, 16>,
128, 1, 16, 1};
}
if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 1 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 48, 32, 32, 1>,
32, 32, 1, 1};
if (s != 1 && fx <= 1 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 32, 32, 32, 1>,
32, 32, 1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 48, 1, 128, 16>,
1, 128, 16, 1};
if (s == 1 && fx <= 1 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 32, 1, 128, 16>,
1, 128, 16, 1};
}
// 4x downsampling (inefficient).
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 48 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 48, 1, 32, 8, 1>,
32, 8, 1, 1};
if (s != 1 && fx <= 32 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 32, 1, 32, 8, 1>,
32, 8, 1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 48, 1, 32, 1, 8>,
32, 1, 8, 1};
if (s == 1 && fx <= 32 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 32, 1, 32, 1, 8>,
32, 1, 8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) {
// contiguous
if (s != 1 && fx <= 1 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 48, 32, 8, 1>,
32, 8, 1, 1};
if (s != 1 && fx <= 1 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 32, 32, 8, 1>,
32, 8, 1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 48, 1, 32, 8>, 1,
32, 8, 1};
if (s == 1 && fx <= 1 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 32, 1, 32, 8>, 1,
32, 8, 1};
}
return spec;
}
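//------------------------------------------------------------------------
// Selection example (read off the table above): for a contiguous NCHW input
// (p.inStride.z != 1) with 2x upsampling, no downsampling and a 4x4 filter,
// the last matching entry is
//   upfirdn2d_kernel_small<T, 2, 2, 1, 1, 4, 4, 64, 16, 1>
// with a 64x16 output tile, loopMinor = 1 and loopX = 1; any configuration
// not covered by a small-kernel entry falls back to upfirdn2d_kernel_large<T>.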
//------------------------------------------------------------------------
// Template specializations.
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>(
const upfirdn2d_kernel_params &p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>(
const upfirdn2d_kernel_params &p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(
const upfirdn2d_kernel_params &p);
//------------------------------------------------------------------------
//------------------------------------------------------------------------
torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy,
int downx, int downy, int padx0, int padx1,
int pady0, int pady1, bool flip, float gain) {
// Validate arguments.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
TORCH_CHECK(f.device() == x.device(),
"f must reside on the same device as x");
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
TORCH_CHECK(x.numel() > 0, "x has zero size");
TORCH_CHECK(f.numel() > 0, "f has zero size");
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
TORCH_CHECK((x.size(0) - 1) * x.stride(0) + (x.size(1) - 1) * x.stride(1) +
(x.size(2) - 1) * x.stride(2) +
(x.size(3) - 1) * x.stride(3) <=
INT_MAX,
"x memory footprint is too large");
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
TORCH_CHECK(downx >= 1 && downy >= 1,
"downsampling factor must be at least 1");
// Create output tensor.
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int outW =
((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
int outH =
((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW},
x.options(), x.suggest_memory_format());
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
TORCH_CHECK((y.size(0) - 1) * y.stride(0) + (y.size(1) - 1) * y.stride(1) +
(y.size(2) - 1) * y.stride(2) +
(y.size(3) - 1) * y.stride(3) <=
INT_MAX,
"output memory footprint is too large");
// Initialize CUDA kernel parameters.
upfirdn2d_kernel_params p;
p.x = x.data_ptr();
p.f = f.data_ptr<float>();
p.y = y.data_ptr();
p.up = make_int2(upx, upy);
p.down = make_int2(downx, downy);
p.pad0 = make_int2(padx0, pady0);
p.flip = (flip) ? 1 : 0;
p.gain = gain;
p.inSize =
make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1),
(int)x.stride(0));
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
p.outSize =
make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1),
(int)y.stride(0));
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
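  // Illustrative consequence of the two lines above: for a contiguous NCHW
  // tensor the channel stride p.inStride.z is H * W != 1, so sizeMajor = N * C
  // and sizeMinor = 1; for a channels-last tensor the channel stride is 1, so
  // sizeMajor = N and sizeMinor = C, letting the kernels treat channels as the
  // fast "minor" axis in that layout.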
// Choose CUDA kernel.
upfirdn2d_kernel_spec spec;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    spec = choose_upfirdn2d_kernel<scalar_t>(p);
  });
// Set looping options.
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
p.loopMinor = spec.loopMinor;
p.loopX = spec.loopX;
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
// Compute grid size.
dim3 blockSize, gridSize;
if (spec.tileOutW < 0) // large
{
blockSize = dim3(4, 32, 1);
gridSize =
dim3(((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, p.launchMajor);
} else // small
{
blockSize = dim3(256, 1, 1);
gridSize =
dim3(((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, p.launchMajor);
}
// Launch CUDA kernel.
void *args[] = {&p};
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
at::cuda::getCurrentCUDAStream()));
return y;
}
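//------------------------------------------------------------------------
// Shape example for the formula used above (illustrative numbers): with
// in_width = 16, upx = 2, padx0 = padx1 = 1, a 4-tap filter (f.size(1) = 4)
// and downx = 1, outW = (16 * 2 + 1 + 1 - 4 + 1) / 1 = 31; the same inputs
// with downx = 2 give outW = (32 + 2 - 4 + 2) / 2 = 16.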
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op_impl(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns) {
return DISPATCH_DEVICE_IMPL(filtered_lrelu_op_impl, x, fu, fd, b, si, up,
down, px0, px1, py0, py1, sx, sy, gain, slope,
clamp, flip_filters, writeSigns);
}
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns) {
return filtered_lrelu_op_impl(x, fu, fd, b, si, up, down, px0, px1, py0, py1,
sx, sy, gain, slope, clamp, flip_filters,
writeSigns);
}
torch::Tensor filtered_lrelu_act_op_impl(torch::Tensor x, torch::Tensor si,
int sx, int sy, float gain,
float slope, float clamp,
bool writeSigns) {
return DISPATCH_DEVICE_IMPL(filtered_lrelu_act_op_impl, x, si, sx, sy, gain,
slope, clamp, writeSigns);
}
torch::Tensor filtered_lrelu_act_(torch::Tensor x, torch::Tensor si, int sx,
int sy, float gain, float slope, float clamp,
bool writeSigns) {
return filtered_lrelu_act_op_impl(x, si, sx, sy, gain, slope, clamp,
writeSigns);
}
......@@ -312,9 +312,9 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const Tensor labels,
const float iou_threshold, const int multi_label);
Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
int pad_y1);
Tensor upfirdn2d(torch::Tensor input, torch::Tensor filter, int upx, int upy,
int downx, int downy, int padx0, int padx1, int pady0,
int pady1, bool flip, float gain);
Tensor fused_bias_leakyrelu(const Tensor &input, const Tensor &bias,
const Tensor &refer, int act, int grad, float alpha,
......@@ -439,6 +439,20 @@ void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);
Tensor bias_act(const Tensor &input, const Tensor &bias, const Tensor &xref,
const Tensor &yref, const Tensor &dy, int grad, int dim,
int act, float alpha, float gain, float clamp);
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns);
torch::Tensor filtered_lrelu_act_(torch::Tensor x, torch::Tensor si, int sx,
int sy, float gain, float slope, float clamp,
bool writeSigns);
void box_iou_quadri(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
......@@ -458,9 +472,9 @@ void bezier_align_backward(Tensor grad_output, Tensor rois, Tensor grad_input,
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
py::arg("down_y"), py::arg("pad_x0"), py::arg("pad_x1"),
py::arg("pad_y0"), py::arg("pad_y1"));
py::arg("filter"), py::arg("upx"), py::arg("upy"), py::arg("downx"),
py::arg("downy"), py::arg("padx0"), py::arg("padx1"), py::arg("pady0"),
py::arg("pady1"), py::arg("flip"), py::arg("gain"));
m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu,
"fused_bias_leakyrelu (CUDA)", py::arg("input"), py::arg("bias"),
py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"),
......@@ -902,6 +916,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("input"), py::arg("rois"), py::arg("grad_rois"),
py::arg("pooled_height"), py::arg("pooled_width"),
py::arg("spatial_scale"));
m.def("bias_act", &bias_act, "bias_act (CUDA)", py::arg("input"),
py::arg("bias"), py::arg("xref"), py::arg("yref"), py::arg("dy"),
py::arg("grad"), py::arg("dim"), py::arg("act"), py::arg("alpha"),
py::arg("gain"), py::arg("clamp"));
m.def("filtered_lrelu", &filtered_lrelu, "filtered_lrelu (CUDA)",
py::arg("x"), py::arg("fu"), py::arg("fd"), py::arg("b"), py::arg("si"),
py::arg("up"), py::arg("down"), py::arg("px0"), py::arg("px1"),
py::arg("py0"), py::arg("py1"), py::arg("sx"), py::arg("sy"),
py::arg("gain"), py::arg("slope"), py::arg("clamp"),
py::arg("flip_filters"), py::arg("writeSigns"));
m.def("filtered_lrelu_act_", &filtered_lrelu_act_,
"filtered_lrelu_act_ (CUDA)", py::arg("x"), py::arg("si"),
py::arg("sx"), py::arg("sy"), py::arg("gain"), py::arg("slope"),
py::arg("clamp"), py::arg("writeSigns"));
m.def("box_iou_quadri", &box_iou_quadri, "IoU for quadrilateral boxes",
py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"),
py::arg("mode_flag"), py::arg("aligned"));
......
......@@ -102,17 +102,17 @@ THE POSSIBILITY OF SUCH DAMAGES.
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
torch::Tensor upfirdn2d_op_impl(const torch::Tensor& input,
const torch::Tensor& kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1) {
return DISPATCH_DEVICE_IMPL(upfirdn2d_op_impl, input, kernel, up_x, up_y,
down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor filter,
int upx, int upy, int downx, int downy,
int padx0, int padx1, int pady0, int pady1,
bool flip, float gain) {
return DISPATCH_DEVICE_IMPL(upfirdn2d_op_impl, input, filter, upx, upy, downx,
downy, padx0, padx1, pady0, pady1, flip, gain);
}
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y, int pad_x0,
int pad_x1, int pad_y0, int pad_y1) {
return upfirdn2d_op_impl(input, kernel, up_x, up_y, down_x, down_y, pad_x0,
pad_x1, pad_y0, pad_y1);
torch::Tensor upfirdn2d(torch::Tensor input, torch::Tensor filter, int upx,
int upy, int downx, int downy, int padx0, int padx1,
int pady0, int pady1, bool flip, float gain) {
return upfirdn2d_op_impl(input, filter, upx, upy, downx, downy, padx0, padx1,
pady0, pady1, flip, gain);
}
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# source: https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/filtered_lrelu.py # noqa
import warnings
from typing import Dict, Optional, Union
import numpy as np
import torch
from ..utils import ext_loader
from .bias_act import bias_act
from .upfirdn2d import _get_filter_size, _parse_padding, upfirdn2d
ext_module = ext_loader.load_ext('_ext',
['filtered_lrelu', 'filtered_lrelu_act_'])
_plugin = None
def filtered_lrelu(input: torch.Tensor,
filter_up: Optional[torch.Tensor] = None,
filter_down: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
up: int = 1,
down: int = 1,
padding: int = 0,
gain: float = np.sqrt(2),
slope: float = 0.2,
clamp: Optional[Union[float, int]] = None,
flip_filter: bool = False,
use_custom_op: bool = True):
"""Filtered leaky ReLU for a batch of 2D images.
Performs the following sequence of operations for each channel:
1. Add channel-specific bias if `bias` is provided.
2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
3. Pad the image with the specified number of zeros on each side
(`padding`). Negative padding corresponds to cropping the image.
4. Convolve the image with the specified upsampling FIR filter
(`filter_up`), shrinking it so that the footprint of all output pixels
lies within the input image.
5. Multiply each value by the provided gain factor (`gain`).
6. Apply leaky ReLU activation function to each value.
7. Clamp each value between -clamp and +clamp, if `clamp` parameter is
provided.
8. Convolve the image with the specified downsampling FIR filter
(`filter_down`), shrinking it so that the footprint of all output
pixels lies within the input image.
9. Downsample the image by keeping every Nth pixel (`down`).
The fused op is considerably more efficient than performing the same
calculation using standard PyTorch ops. It supports gradients of arbitrary
order.
Args:
input (torch.Tensor): Float32/float16/float64 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
filter_up (torch.Tensor): Float32 upsampling FIR filter of the shape
`[filter_height, filter_width]` (non-separable), `[filter_taps]`
(separable), or `None` (identity). Defaults to None.
filter_down (torch.Tensor): Float32 downsampling FIR filter of the
shape `[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable), or `None` (identity).
Defaults to None.
bias (torch.Tensor): Bias vector, or `None` to disable. Must be
a 1D tensor of the same type as `input`. The length of vector must
match the channel dimension of `input`. Defaults to None.
up (int): Integer upsampling factor. Defaults to 1.
down (int): Integer downsampling factor. Defaults to 1.
padding (int): Padding with respect to the upsampled image. Can be a
single number or a list/tuple `[x, y]` or `[x_before, x_after,
y_before, y_after]`. Defaults to 0.
gain (float): Overall scaling factor for signal magnitude.
Defaults to np.sqrt(2).
slope (float): Slope on the negative side of leaky ReLU.
Defaults to 0.2.
clamp (Optional[Union[float, int]]): Maximum magnitude for leaky ReLU
output. Defaults to None.
flip_filter (bool): False = convolution, True = correlation.
Defaults to False.
        use_custom_op (bool): Whether to use the custom fused CUDA op when
            the input resides on a CUDA device. Defaults to True.
Returns:
Tensor of the shape `[batch_size, num_channels, out_height,
out_width]`.
"""
assert isinstance(input, torch.Tensor)
if use_custom_op and input.is_cuda:
return _filtered_lrelu_cuda(
up=up,
down=down,
padding=padding,
gain=gain,
slope=slope,
clamp=clamp,
flip_filter=flip_filter).apply(input, filter_up, filter_down, bias,
None, 0, 0)
return _filtered_lrelu_ref(
input,
filter_up=filter_up,
filter_down=filter_down,
bias=bias,
up=up,
down=down,
padding=padding,
gain=gain,
slope=slope,
clamp=clamp,
flip_filter=flip_filter)
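# A minimal usage sketch (tensor names and values below are illustrative, not
# part of the original module):
#
#   x = torch.randn(2, 3, 16, 16, device='cuda')
#   taps = torch.tensor([1., 3., 3., 1.], device='cuda') / 8.
#   y = filtered_lrelu(x, filter_up=taps, filter_down=taps, up=2, down=2,
#                      padding=1)
#   # With 4-tap filters, out_h = (16 * 2 + 1 + 1 - 3 - 3 + 1) // 2 = 14,
#   # so y has shape [2, 3, 14, 14].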
def _filtered_lrelu_ref(input: torch.Tensor,
filter_up: Optional[torch.Tensor] = None,
filter_down: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
up: int = 1,
down: int = 1,
padding: int = 0,
gain: float = np.sqrt(2),
slope: float = 0.2,
clamp: Optional[Union[float, int]] = None,
flip_filter: bool = False):
"""Slow and memory-inefficient reference implementation of
    `filtered_lrelu()` using existing `upfirdn2d()` and `bias_act()` ops.
Args:
input (torch.Tensor): Float32/float16/float64 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
filter_up (torch.Tensor): Float32 upsampling FIR filter of the shape
`[filter_height, filter_width]` (non-separable), `[filter_taps]`
(separable), or `None` (identity). Defaults to None.
filter_down (torch.Tensor): Float32 downsampling FIR filter of the
shape `[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable), or `None` (identity).
Defaults to None.
bias (torch.Tensor): Bias vector, or `None` to disable. Must be
a 1D tensor of the same type as `input`. The length of vector must
match the channel dimension of `input`. Defaults to None.
up (int): Integer upsampling factor. Defaults to 1.
down (int): Integer downsampling factor. Defaults to 1.
padding (int): Padding with respect to the upsampled image. Can be a
single number or a list/tuple `[x, y]` or `[x_before, x_after,
y_before, y_after]`. Defaults to 0.
gain (float): Overall scaling factor for signal magnitude.
Defaults to np.sqrt(2).
slope (float): Slope on the negative side of leaky ReLU.
Defaults to 0.2.
clamp (float or int): Maximum magnitude for leaky ReLU
output. Defaults to None.
flip_filter (bool): False = convolution, True = correlation.
Defaults to False.
Returns:
Tensor of the shape `[batch_size, num_channels, out_height,
out_width]`.
"""
assert isinstance(input, torch.Tensor) and input.ndim == 4
filter_up_w, filter_up_h = _get_filter_size(filter_up)
filter_down_w, filter_down_h = _get_filter_size(filter_down)
if bias is not None:
assert isinstance(bias, torch.Tensor) and bias.dtype == input.dtype
assert isinstance(up, int) and up >= 1
assert isinstance(down, int) and down >= 1
px0, px1, py0, py1 = _parse_padding(padding)
assert gain == float(gain) and gain > 0
assert slope == float(slope) and slope >= 0
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
# Calculate output size.
batch_size, channels, in_h, in_w = input.shape
in_dtype = input.dtype
out_w = (in_w * up + (px0 + px1) - (filter_up_w - 1) -
(filter_down_w - 1) + (down - 1)) // down
out_h = (in_h * up + (py0 + py1) - (filter_up_h - 1) -
(filter_down_h - 1) + (down - 1)) // down
# Compute using existing ops.
output = bias_act(input=input, bias=bias) # Apply bias.
output = upfirdn2d(
input=output,
filter=filter_up,
up=up,
padding=[px0, px1, py0, py1],
gain=up**2,
flip_filter=flip_filter) # Upsample.
output = bias_act(
input=output, act='lrelu', alpha=slope, gain=gain,
clamp=clamp) # Bias, leaky ReLU, clamp.
output = upfirdn2d(
input=output, filter=filter_down, down=down,
flip_filter=flip_filter) # Downsample.
assert output.shape == (batch_size, channels, out_h, out_w)
assert output.dtype == in_dtype
return output
_filtered_lrelu_cuda_cache: Dict = dict()
def _filtered_lrelu_cuda(up: int = 1,
down: int = 1,
padding: int = 0,
gain: float = np.sqrt(2),
slope: float = 0.2,
clamp: Optional[Union[float, int]] = None,
flip_filter: bool = False):
"""Fast CUDA implementation of `filtered_lrelu()` using custom ops.
Args:
up (int): Integer upsampling factor. Defaults to 1.
down (int): Integer downsampling factor. Defaults to 1.
padding (int): Padding with respect to the upsampled image. Can be a
single number or a list/tuple `[x, y]` or `[x_before, x_after,
y_before, y_after]`. Defaults to 0.
gain (float): Overall scaling factor for signal magnitude.
Defaults to np.sqrt(2).
slope (float): Slope on the negative side of leaky ReLU.
Defaults to 0.2.
clamp (float or int): Maximum magnitude for leaky ReLU
output. Defaults to None.
flip_filter (bool): False = convolution, True = correlation.
Defaults to False.
Returns:
        A `torch.autograd.Function` subclass implementing `filtered_lrelu()`
        for the given parameters; instances are cached and reused.
"""
assert isinstance(up, int) and up >= 1
assert isinstance(down, int) and down >= 1
px0, px1, py0, py1 = _parse_padding(padding)
assert gain == float(gain) and gain > 0
gain = float(gain)
assert slope == float(slope) and slope >= 0
slope = float(slope)
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
clamp = float(clamp if clamp is not None else 'inf')
# Lookup from cache.
key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
if key in _filtered_lrelu_cuda_cache:
return _filtered_lrelu_cuda_cache[key]
# Forward op.
class FilteredLReluCuda(torch.autograd.Function):
@staticmethod
def forward(ctx, input, filter_up, filter_down, bias, si, sx, sy):
# pylint: disable=arguments-differ
assert isinstance(input, torch.Tensor) and input.ndim == 4
# Replace empty up/downsample kernels with full 1x1 kernels
# (faster than separable).
if filter_up is None:
filter_up = torch.ones([1, 1],
dtype=torch.float32,
device=input.device)
if filter_down is None:
filter_down = torch.ones([1, 1],
dtype=torch.float32,
device=input.device)
assert 1 <= filter_up.ndim <= 2
assert 1 <= filter_down.ndim <= 2
# Replace separable 1x1 kernels with full 1x1 kernels when scale
# factor is 1.
if up == 1 and filter_up.ndim == 1 and filter_up.shape[0] == 1:
filter_up = filter_up.square()[None]
if down == 1 and filter_down.ndim == 1 and filter_down.shape[
0] == 1:
filter_down = filter_down.square()[None]
# Missing sign input tensor.
if si is None:
si = torch.empty([0])
# Missing bias tensor.
if bias is None:
bias = torch.zeros([input.shape[1]],
dtype=input.dtype,
device=input.device)
# Construct internal sign tensor only if gradients are needed.
write_signs = (si.numel() == 0) and (input.requires_grad
or bias.requires_grad)
# Warn if input storage strides are not in decreasing order due to
# e.g. channels-last layout.
strides = [
input.stride(i) for i in range(input.ndim) if input.size(i) > 1
]
if any(a < b for a, b in zip(strides[:-1], strides[1:])):
warnings.warn(
'low-performance memory layout detected in filtered_lrelu '
'input', RuntimeWarning)
# Call C++/Cuda plugin if datatype is supported.
if input.dtype in [torch.float16, torch.float32]:
if torch.cuda.current_stream(
input.device) != torch.cuda.default_stream(
input.device):
warnings.warn(
'filtered_lrelu called with non-default cuda stream '
'but concurrent execution is not supported',
RuntimeWarning)
y, so, return_code = ext_module.filtered_lrelu(
input, filter_up, filter_down, bias, si.to(input.device),
up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp,
flip_filter, write_signs)
else:
return_code = -1
# No Cuda kernel found? Fall back to generic implementation.
# Still more memory efficient than the reference implementation
# because only the bit-packed sign tensor is retained for gradient
# computation.
if return_code < 0:
warnings.warn(
'filtered_lrelu called with parameters that have no '
'optimized CUDA kernel, using generic fallback',
RuntimeWarning)
y = input.add(bias.unsqueeze(-1).unsqueeze(-1)) # Add bias.
y = upfirdn2d(
input=y,
filter=filter_up,
up=up,
padding=[px0, px1, py0, py1],
gain=float(up**2),
flip_filter=flip_filter) # Upsample.
# Activation function and sign handling. Modifies y in-place.
so = ext_module.filtered_lrelu_act_(y, si.to(y.device), sx, sy,
gain, slope, clamp,
write_signs)
y = upfirdn2d(
input=y,
filter=filter_down,
down=down,
flip_filter=flip_filter) # Downsample.
# Prepare for gradient computation.
ctx.save_for_backward(filter_up, filter_down,
(si if si.numel() else so))
ctx.x_shape = input.shape
ctx.y_shape = y.shape
ctx.s_ofs = sx, sy
return y
@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
filter_up, filter_down, si = ctx.saved_tensors
_, _, xh, xw = ctx.x_shape
_, _, yh, yw = ctx.y_shape
sx, sy = ctx.s_ofs
dx = None # 0
dfu = None
assert not ctx.needs_input_grad[1]
dfd = None
assert not ctx.needs_input_grad[2]
db = None # 3
dsi = None
assert not ctx.needs_input_grad[4]
dsx = None
assert not ctx.needs_input_grad[5]
dsy = None
assert not ctx.needs_input_grad[6]
if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
pp = [
(filter_up.shape[-1] - 1) + (filter_down.shape[-1] - 1) -
px0,
xw * up - yw * down + px0 - (up - 1),
(filter_up.shape[0] - 1) + (filter_down.shape[0] - 1) -
py0,
xh * up - yh * down + py0 - (up - 1),
]
gg = gain * (up**2) / (down**2)
ff = (not flip_filter)
sx = sx - (filter_up.shape[-1] - 1) + px0
sy = sy - (filter_up.shape[0] - 1) + py0
dx = _filtered_lrelu_cuda(
up=down,
down=up,
padding=pp,
gain=gg,
slope=slope,
clamp=None,
flip_filter=ff).apply(dy, filter_down, filter_up, None, si,
sx, sy)
if ctx.needs_input_grad[3]:
db = dx.sum([0, 2, 3])
return dx, dfu, dfd, db, dsi, dsx, dsy
# Add to cache.
_filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
return FilteredLReluCuda
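# Editor's note: a minimal usage sketch of the cached factory above. The
# tensors and parameters are hypothetical; this mirrors how the public
# `filtered_lrelu` wrapper is assumed to drive the op on a CUDA device.
#
#     x = torch.randn(1, 3, 16, 16, device='cuda')
#     fu = fd = torch.ones(2, 2, device='cuda') / 4   # 2x2 box filters
#     b = torch.zeros(3, device='cuda')
#     op = _filtered_lrelu_cuda(up=2, down=2, padding=1, slope=0.2, clamp=None)
#     y = op.apply(x, fu, fd, b, None, 0, 0)          # si=None, sx=sy=0
#
# Identical (up, down, padding, gain, slope, clamp, flip_filter) keys reuse
# the cached autograd Function instead of rebuilding it.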
# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
# Augmentation (ADA)
# =======================================================================
# 1. Definitions
# "Licensor" means any person or entity that distributes its Work.
# "Software" means the original work of authorship made available under
# this License.
# "Work" means the Software and any additions to or derivative works of
# the Software that are made available under this License.
# The terms "reproduce," "reproduction," "derivative works," and
# "distribution" have the meaning as provided under U.S. copyright law;
# provided, however, that for the purposes of this License, derivative
# works shall not include works that remain separable from, or merely
# link (or bind by name) to the interfaces of, the Work.
# Works, including the Software, are "made available" under this License
# by including in or with the Work either (a) a copyright notice
# referencing the applicability of this License to the Work, or (b) a
# copy of this License.
# 2. License Grants
# 2.1 Copyright Grant. Subject to the terms and conditions of this
# License, each Licensor grants to you a perpetual, worldwide,
# non-exclusive, royalty-free, copyright license to reproduce,
# prepare derivative works of, publicly display, publicly perform,
# sublicense and distribute its Work and any resulting derivative
# works in any form.
# 3. Limitations
# 3.1 Redistribution. You may reproduce or distribute the Work only
# if (a) you do so under this License, (b) you include a complete
# copy of this License with your distribution, and (c) you retain
# without modification any copyright, patent, trademark, or
# attribution notices that are present in the Work.
# 3.2 Derivative Works. You may specify that additional or different
# terms apply to the use, reproduction, and distribution of your
# derivative works of the Work ("Your Terms") only if (a) Your Terms
# provide that the use limitation in Section 3.3 applies to your
# derivative works, and (b) you identify the specific derivative
# works that are subject to Your Terms. Notwithstanding Your Terms,
# this License (including the redistribution requirements in Section
# 3.1) will continue to apply to the Work itself.
# 3.3 Use Limitation. The Work and any derivative works thereof only
# may be used or intended for use non-commercially. Notwithstanding
# the foregoing, NVIDIA and its affiliates may use the Work and any
# derivative works commercially. As used herein, "non-commercially"
# means for research or evaluation purposes only.
# 3.4 Patent Claims. If you bring or threaten to bring a patent claim
# against any Licensor (including any claim, cross-claim or
# counterclaim in a lawsuit) to enforce any patents that you allege
# are infringed by any Work, then your rights under this License from
# such Licensor (including the grant in Section 2.1) will terminate
# immediately.
# 3.5 Trademarks. This License does not grant any rights to use any
# Licensor’s or its affiliates’ names, logos, or trademarks, except
# as necessary to reproduce the notices described in this License.
# 3.6 Termination. If you violate any term of this License, then your
# rights under this License (including the grant in Section 2.1) will
# terminate immediately.
# 4. Disclaimer of Warranty.
# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
# THIS LICENSE.
# 5. Limitation of Liability.
# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
# THE POSSIBILITY OF SUCH DAMAGES.
# =======================================================================
from typing import Any, List, Tuple, Union
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# source: https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/upfirdn2d.py # noqa
"""Custom PyTorch ops for efficient resampling of 2D images."""
from typing import Dict, List, Union
import torch
from mmengine.utils import to_2tuple
from torch.autograd import Function
from torch.nn import functional as F
from ..utils import ext_loader
from .conv2d_gradfix import conv2d
ext_module = ext_loader.load_ext('_ext', ['upfirdn2d'])
def _parse_scaling(scaling):
"""parse scaling into list [x, y]"""
if isinstance(scaling, int):
scaling = [scaling, scaling]
assert isinstance(scaling, (list, tuple))
assert all(isinstance(x, int) for x in scaling)
sx, sy = scaling
assert sx >= 1 and sy >= 1
return sx, sy
def _parse_padding(padding):
"""parse padding into list [padx0, padx1, pady0, pady1]"""
if isinstance(padding, int):
padding = [padding, padding]
assert isinstance(padding, (list, tuple))
assert all(isinstance(x, int) for x in padding)
if len(padding) == 2:
padx, pady = padding
padding = [padx, padx, pady, pady]
padx0, padx1, pady0, pady1 = padding
return padx0, padx1, pady0, pady1
def _get_filter_size(filter):
"""get width and height of filter kernel."""
if filter is None:
return 1, 1
assert isinstance(filter, torch.Tensor) and filter.ndim in [1, 2]
fw = filter.shape[-1]
fh = filter.shape[0]
fw = int(fw)
fh = int(fh)
assert fw >= 1 and fh >= 1
return fw, fh
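# Editor's note: a quick worked example of the parsing helpers above
# (illustrative values only):
#
#     _parse_scaling(2)                  -> (2, 2)
#     _parse_scaling([2, 1])             -> (2, 1)
#     _parse_padding(1)                  -> (1, 1, 1, 1)
#     _parse_padding([1, 2])             -> (1, 1, 2, 2)
#     _parse_padding([0, 1, 2, 3])       -> (0, 1, 2, 3)
#     _get_filter_size(torch.ones(3, 3)) -> (3, 3)   # (width, height)
#     _get_filter_size(None)             -> (1, 1)   # identity filter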
def upfirdn2d(input: torch.Tensor,
filter: torch.Tensor,
up: int = 1,
down: int = 1,
padding: Union[int, List[int]] = 0,
flip_filter: bool = False,
gain: Union[float, int] = 1,
use_custom_op: bool = True):
"""Pad, upsample, filter, and downsample a batch of 2D images.
Performs the following sequence of operations for each channel:
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
2. Pad the image with the specified number of zeros on each side
(`padding`). Negative padding corresponds to cropping the image.
    3. Convolve the image with the specified 2D FIR filter (`filter`),
shrinking it so that the footprint of all output pixels lies within
the input image.
4. Downsample the image by keeping every Nth pixel (`down`).
This sequence of operations bears close resemblance to
scipy.signal.upfirdn().
The fused op is considerably more efficient than performing the same
calculation using standard PyTorch ops. It supports gradients of arbitrary
order.
upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
class UpFirDn2dBackward(Function):
@staticmethod
def forward(ctx: Any, grad_output: torch.Tensor, kernel: torch.Tensor,
grad_kernel: torch.Tensor, up: tuple, down: tuple, pad: tuple,
g_pad: tuple, in_size: Union[List, Tuple],
out_size: Union[List, Tuple]) -> torch.Tensor:
up_x, up_y = up
down_x, down_y = down
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
grad_input = upfirdn2d_ext.upfirdn2d(
grad_output,
grad_kernel,
up_x=down_x,
up_y=down_y,
down_x=up_x,
down_y=up_y,
pad_x0=g_pad_x0,
pad_x1=g_pad_x1,
pad_y0=g_pad_y0,
pad_y1=g_pad_y1)
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
in_size[3])
ctx.save_for_backward(kernel)
pad_x0, pad_x1, pad_y0, pad_y1 = pad
ctx.up_x = up_x
ctx.up_y = up_y
ctx.down_x = down_x
ctx.down_y = down_y
ctx.pad_x0 = pad_x0
ctx.pad_x1 = pad_x1
ctx.pad_y0 = pad_y0
ctx.pad_y1 = pad_y1
ctx.in_size = in_size
ctx.out_size = out_size
return grad_input
@staticmethod
def backward(ctx: Any, gradgrad_input: torch.Tensor) -> tuple:
kernel, = ctx.saved_tensors
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
ctx.in_size[3], 1)
gradgrad_out = upfirdn2d_ext.upfirdn2d(
gradgrad_input,
kernel,
up_x=ctx.up_x,
up_y=ctx.up_y,
down_x=ctx.down_x,
down_y=ctx.down_y,
pad_x0=ctx.pad_x0,
pad_x1=ctx.pad_x1,
pad_y0=ctx.pad_y0,
pad_y1=ctx.pad_y1)
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
# ctx.out_size[1], ctx.in_size[3])
gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
ctx.out_size[0], ctx.out_size[1])
return gradgrad_out, None, None, None, None, None, None, None, None
class UpFirDn2d(Function):
@staticmethod
def forward(ctx: Any, input: torch.Tensor, kernel: torch.Tensor, up: tuple,
down: tuple, pad: tuple) -> torch.Tensor:
up_x, up_y = up
down_x, down_y = down
pad_x0, pad_x1, pad_y0, pad_y1 = pad
kernel_h, kernel_w = kernel.shape
batch, channel, in_h, in_w = input.shape
ctx.in_size = input.shape
input = input.reshape(-1, in_h, in_w, 1)
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
ctx.out_size = (out_h, out_w)
ctx.up = (up_x, up_y)
ctx.down = (down_x, down_y)
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
g_pad_x0 = kernel_w - pad_x0 - 1
g_pad_y0 = kernel_h - pad_y0 - 1
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
out = upfirdn2d_ext.upfirdn2d(
input,
kernel,
up_x=up_x,
up_y=up_y,
down_x=down_x,
down_y=down_y,
pad_x0=pad_x0,
pad_x1=pad_x1,
pad_y0=pad_y0,
pad_y1=pad_y1)
# out = out.view(major, out_h, out_w, minor)
out = out.view(-1, channel, out_h, out_w)
return out
@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
kernel, grad_kernel = ctx.saved_tensors
grad_input = UpFirDn2dBackward.apply(
grad_output,
kernel,
grad_kernel,
ctx.up,
ctx.down,
ctx.pad,
ctx.g_pad,
ctx.in_size,
ctx.out_size,
)
return grad_input, None, None, None, None
def upfirdn2d(
input: torch.Tensor,
kernel: torch.Tensor,
up: Union[int, tuple] = 1,
down: Union[int, tuple] = 1,
pad: tuple = (0, 0)) -> torch.Tensor: # noqa E125
"""UpFRIDn for 2d features.
UpFIRDn is short for upsample, apply FIR filter and downsample. More
details can be found in:
https://www.mathworks.com/help/signal/ref/upfirdn.html
Args:
input (torch.Tensor): Float32/float64/float16 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
filter (torch.Tensor): Float32 FIR filter of the shape `[filter_height,
filter_width]` (non-separable), `[filter_taps]` (separable), or
`None` (identity).
up (int): Integer upsampling factor. Can be a single int or a
list/tuple `[x, y]`. Defaults to 1.
down (int): Integer downsampling factor. Can be a single int
or a list/tuple `[x, y]`. Defaults to 1.
padding (int | tuple[int]): Padding with respect to the upsampled
image. Can be a single number or a list/tuple `[x, y]` or
`[x_before, x_after, y_before, y_after]`. Defaults to 0.
flip_filter (bool): False = convolution, True = correlation.
Defaults to False.
gain (int): Overall scaling factor for signal magnitude.
Defaults to 1.
use_custom_op (bool): Whether to use customized op.
Defaults to True.
Returns:
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`
"""
assert isinstance(input, torch.Tensor)
if use_custom_op and input.device.type == 'cuda':
return _upfirdn2d_cuda(
up=up,
down=down,
padding=padding,
flip_filter=flip_filter,
gain=gain).apply(input, filter)
return _upfirdn2d_ref(
input,
filter,
up=up,
down=down,
padding=padding,
flip_filter=flip_filter,
gain=gain)
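# Editor's note: an illustrative call (hypothetical shapes). Each spatial dim
# follows `(in * up + pad_before + pad_after - filter_size) // down + 1`; on
# CPU the call falls through to `_upfirdn2d_ref` below.
#
#     x = torch.randn(2, 3, 16, 16)
#     f = torch.ones(3, 3) / 9                 # simple 3x3 box filter
#     y = upfirdn2d(x, f, up=2, padding=1)
#     # (16 * 2 + 1 + 1 - 3) // 1 + 1 = 32  ->  y.shape == (2, 3, 32, 32)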
def _upfirdn2d_ref(input: torch.Tensor,
filter: torch.Tensor,
up: int = 1,
down: int = 1,
padding: Union[int, List[int]] = 0,
flip_filter: bool = False,
gain: Union[float, int] = 1):
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch
ops.
Args:
input (torch.Tensor): Float32/float64/float16 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
filter (torch.Tensor): Float32 FIR filter of the shape `[filter_height,
filter_width]` (non-separable), `[filter_taps]` (separable), or
`None` (identity).
up (int): Integer upsampling factor. Can be a single int or a
list/tuple `[x, y]`. Defaults to 1.
down (int): Integer downsampling factor. Can be a single int
or a list/tuple `[x, y]`. Defaults to 1.
padding (int | tuple[int]): Padding with respect to the upsampled
image. Can be a single number or a list/tuple `[x, y]` or
`[x_before, x_after, y_before, y_after]`. Defaults to 0.
flip_filter (bool): False = convolution, True = correlation.
Defaults to False.
gain (int): Overall scaling factor for signal magnitude.
Defaults to 1.
Returns:
torch.Tensor: Tensor of the shape `[batch_size, num_channels,
out_height, out_width]`.
"""
# Validate arguments.
assert isinstance(input, torch.Tensor) and input.ndim == 4
if filter is None:
filter = torch.ones([1, 1], dtype=torch.float32, device=input.device)
assert isinstance(filter, torch.Tensor) and filter.ndim in [1, 2]
assert filter.dtype == torch.float32 and not filter.requires_grad
batch_size, num_channels, in_height, in_width = input.shape
upx, upy = _parse_scaling(up)
downx, downy = _parse_scaling(down)
padx0, padx1, pady0, pady1 = _parse_padding(padding)
# Check that upsampled buffer is not smaller than the filter.
upW = in_width * upx + padx0 + padx1
upH = in_height * upy + pady0 + pady1
assert upW >= filter.shape[-1] and upH >= filter.shape[0]
# Upsample by inserting zeros.
x = input.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
# Pad or crop.
x = torch.nn.functional.pad(
x, [max(padx0, 0),
max(padx1, 0),
max(pady0, 0),
max(pady1, 0)])
x = x[:, :,
max(-pady0, 0):x.shape[2] - max(-pady1, 0),
max(-padx0, 0):x.shape[3] - max(-padx1, 0)]
# Setup filter.
filter = filter * (gain**(filter.ndim / 2))
filter = filter.to(x.dtype)
if not flip_filter:
filter = filter.flip(list(range(filter.ndim)))
# Convolve with the filter.
filter = filter[None, None].repeat([num_channels, 1] + [1] * filter.ndim)
if filter.ndim == 4:
x = conv2d(input=x, weight=filter, groups=num_channels)
else:
x = conv2d(input=x, weight=filter.unsqueeze(2), groups=num_channels)
x = conv2d(input=x, weight=filter.unsqueeze(3), groups=num_channels)
# Downsample by throwing away pixels.
x = x[:, :, ::downy, ::downx]
return x
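# Editor's note: the reference path above and the CUDA path below produce the
# same output size per spatial dim:
#     out = (in * up + pad_before + pad_after - filter_size) // down + 1
# e.g. in=4, up=2, padding=1, 3x3 filter, down=1 -> (8 + 2 - 3) // 1 + 1 = 8.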
_upfirdn2d_cuda_cache: Dict = dict()
def _upfirdn2d_cuda(up: int = 1,
down: int = 1,
padding: Union[int, List[int]] = 0,
flip_filter: bool = False,
gain: Union[float, int] = 1):
"""Fast CUDA implementation of `upfirdn2d()` using custom ops.
    Args:
        up (int | tuple[int]): Integer upsampling factor. Can be a single
            int or a list/tuple `[x, y]`. Defaults to 1.
        down (int | tuple[int]): Integer downsampling factor. Can be a
            single int or a list/tuple `[x, y]`. Defaults to 1.
        padding (int | tuple[int]): Padding with respect to the upsampled
            image. Can be a single number or a list/tuple `[x, y]` or
            `[x_before, x_after, y_before, y_after]`. Defaults to 0.
        flip_filter (bool): False = convolution, True = correlation.
            Defaults to False.
        gain (float | int): Overall scaling factor for signal magnitude.
            Defaults to 1.
    Returns:
        torch.autograd.Function: The cached CUDA op specialized for these
        parameters. Calling its `apply(input, filter)` returns a tensor of
        the shape `[batch_size, num_channels, out_height, out_width]`.
"""
# Parse arguments.
upx, upy = _parse_scaling(up)
downx, downy = _parse_scaling(down)
padx0, padx1, pady0, pady1 = _parse_padding(padding)
# Lookup from cache.
key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter,
gain)
if key in _upfirdn2d_cuda_cache:
return _upfirdn2d_cuda_cache[key]
# Forward op.
class Upfirdn2dCuda(torch.autograd.Function):
@staticmethod
def forward(ctx, x, f): # pylint: disable=arguments-differ
assert isinstance(x, torch.Tensor) and x.ndim == 4
if f is None:
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
if f.ndim == 1 and f.shape[0] == 1:
f = f.square().unsqueeze(
0) # Convert separable-1 into full-1x1.
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
y = x
if f.ndim == 2:
y = ext_module.upfirdn2d(y, f, upx, upy, downx, downy, padx0,
padx1, pady0, pady1, flip_filter,
gain)
else:
y = ext_module.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1,
padx0, padx1, 0, 0, flip_filter, 1.0)
y = ext_module.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy,
0, 0, pady0, pady1, flip_filter, gain)
ctx.save_for_backward(f)
ctx.x_shape = x.shape
return y
@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
f, = ctx.saved_tensors
_, _, ih, iw = ctx.x_shape
_, _, oh, ow = dy.shape
fw, fh = _get_filter_size(f)
p = [
fw - padx0 - 1,
iw * upx - ow * downx + padx0 - upx + 1,
fh - pady0 - 1,
ih * upy - oh * downy + pady0 - upy + 1,
]
dx = None
df = None
if ctx.needs_input_grad[0]:
dx = _upfirdn2d_cuda(
up=down,
down=up,
padding=p,
flip_filter=(not flip_filter),
gain=gain).apply(dy, f)
assert not ctx.needs_input_grad[1]
return dx, df
# Add to cache.
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
return Upfirdn2dCuda
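# Editor's note: an illustrative sketch of the caching pattern above (assumes
# a CUDA device; `x` is an NCHW tensor and `f` a 2D FIR filter):
#
#     op = _upfirdn2d_cuda(up=2, down=1, padding=[1, 1, 1, 1])
#     y = op.apply(x, f)
#     assert _upfirdn2d_cuda(up=2, down=1, padding=[1, 1, 1, 1]) is op
#
# The same parameter key returns the same specialized autograd Function, so
# each configuration is built only once.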
def filter2d(input: torch.Tensor,
filter: torch.Tensor,
padding: Union[int, List[int]] = 0,
flip_filter: bool = False,
gain: Union[float, int] = 1,
use_custom_op: bool = True):
"""Filter a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape matches the input.
User-specified padding is applied on top of that, with negative values
indicating cropping. Pixels outside the image are assumed to be zero.
Args:
input (torch.Tensor): Float32/float64/float16 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
filter (torch.Tensor): Float32 FIR filter of the shape `[filter_height,
filter_width]` (non-separable), `[filter_taps]` (separable), or
`None`.
padding (int | tuple[int]): Padding with respect to the output.
Can be a single number or a list/tuple `[x, y]` or `[x_before,
x_after, y_before, y_after]`. Defaults to 0.
flip_filter (bool): False = convolution, True = correlation.
Defaults to False.
gain (int): Overall scaling factor for signal magnitude.
Defaults to 1.
        use_custom_op (bool): Whether to use customized op.
            Defaults to True.
    Returns:
        torch.Tensor: Tensor of the shape `[batch_size, num_channels,
        out_height, out_width]`.
"""
padx0, padx1, pady0, pady1 = _parse_padding(padding)
fw, fh = _get_filter_size(filter)
p = [
padx0 + fw // 2,
padx1 + (fw - 1) // 2,
pady0 + fh // 2,
pady1 + (fh - 1) // 2,
]
return upfirdn2d(
input,
filter,
padding=p,
flip_filter=flip_filter,
gain=gain,
use_custom_op=use_custom_op)
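# Editor's note: with the default `padding=0`, `filter2d` adds `fw - 1` /
# `fh - 1` pixels of padding in total, so the output keeps the input
# resolution (illustrative shapes):
#
#     x = torch.randn(1, 3, 16, 16)
#     f = torch.ones(3, 3) / 9
#     y = filter2d(x, f)                        # y.shape == (1, 3, 16, 16)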
def upsample2d(input: torch.Tensor,
filter: torch.Tensor,
up: int = 2,
padding: Union[int, List[int]] = 0,
flip_filter: bool = False,
gain: Union[float, int] = 1,
use_custom_op: bool = True):
"""Upsample a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape is a multiple of the
input.
User-specified padding is applied on top of that, with negative values
indicating cropping. Pixels outside the image are assumed to be zero.
Args:
input (torch.Tensor): Float32/float64/float16 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
filter (torch.Tensor): Float32 FIR filter of the shape `[filter_height,
filter_width]` (non-separable), `[filter_taps]` (separable), or
`None` (identity).
up (int): Integer upsampling factor. Can be a single int or a
list/tuple `[x, y]`. Defaults to 2.
padding (int | tuple[int]): Padding with respect to the output.
Can be a single number or a list/tuple `[x, y]` or `[x_before,
x_after, y_before, y_after]`. Defaults to 0.
flip_filter (bool): False = convolution, True = correlation. Defaults
to False.
gain (int): Overall scaling factor for signal magnitude. Defaults to 1.
use_custom_op (bool): Whether to use customized op.
Defaults to True.
Returns:
torch.Tensor: Tensor of the shape `[batch_size, num_channels,
out_height, out_width]`
"""
upx, upy = _parse_scaling(up)
padx0, padx1, pady0, pady1 = _parse_padding(padding)
fw, fh = _get_filter_size(filter)
p = [
padx0 + (fw + upx - 1) // 2,
padx1 + (fw - upx) // 2,
pady0 + (fh + upy - 1) // 2,
pady1 + (fh - upy) // 2,
]
return upfirdn2d(
input,
filter,
up=up,
padding=p,
flip_filter=flip_filter,
gain=gain * upx * upy,
use_custom_op=use_custom_op)
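# Editor's note: `upsample2d` centres the filter and multiplies `gain` by
# `upx * upy` to compensate for the energy lost to zero-insertion
# (illustrative shapes):
#
#     x = torch.randn(1, 3, 16, 16)
#     f = torch.ones(2, 2) / 4
#     y = upsample2d(x, f, up=2)                # y.shape == (1, 3, 32, 32)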
def downsample2d(input: torch.Tensor,
filter: torch.Tensor,
down: int = 2,
padding: Union[int, List[int]] = 0,
flip_filter: bool = False,
gain: Union[float, int] = 1,
use_custom_op: bool = True):
"""Downsample a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape is a fraction of the
input.
User-specified padding is applied on top of that, with negative values
indicating cropping. Pixels outside the image are assumed to be zero.
def upfirdn2d_native(input: torch.Tensor, kernel: torch.Tensor, up_x: int,
up_y: int, down_x: int, down_y: int, pad_x0: int,
pad_x1: int, pad_y0: int, pad_y1: int) -> torch.Tensor:
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(
out,
[0, 0,
max(pad_x0, 0),
max(pad_x1, 0),
max(pad_y0, 0),
max(pad_y1, 0)])
out = out[:,
max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
out = out.permute(0, 3, 1, 2)
out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
Args:
input (torch.Tensor): Float32/float64/float16 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
filter (torch.Tensor): Float32 FIR filter of the shape `[filter_height,
filter_width]` (non-separable), `[filter_taps]` (separable), or
`None` (identity).
        down (int): Integer downsampling factor. Can be a single int or a
            list/tuple `[x, y]`. Defaults to 2.
padding (int | tuple[int]): Padding with respect to the input.
Can be a single number or a list/tuple `[x, y]` or `[x_before,
x_after, y_before, y_after]`. Defaults to 0.
flip_filter (bool): False = convolution, True = correlation. Defaults
to False.
gain (int): Overall scaling factor for signal magnitude. Defaults to 1.
use_custom_op (bool): Whether to use customized op.
Defaults to True.
Returns:
torch.Tensor: Tensor of the shape `[batch_size, num_channels,
out_height, out_width]`.
"""
downx, downy = _parse_scaling(down)
padx0, padx1, pady0, pady1 = _parse_padding(padding)
fw, fh = _get_filter_size(filter)
p = [
padx0 + (fw - downx + 1) // 2,
padx1 + (fw - downx) // 2,
pady0 + (fh - downy + 1) // 2,
pady1 + (fh - downy) // 2,
]
return upfirdn2d(
input,
filter,
down=down,
padding=p,
flip_filter=flip_filter,
gain=gain,
use_custom_op=use_custom_op)
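# Editor's note: `downsample2d` is the complementary helper (illustrative
# shapes):
#
#     x = torch.randn(1, 3, 32, 32)
#     f = torch.ones(2, 2) / 4
#     y = downsample2d(x, f, down=2)            # y.shape == (1, 3, 16, 16)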
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.ops import bias_act
from mmcv.ops.bias_act import EasyDict
_USING_PARROTS = True
try:
from parrots.autograd import gradcheck
except ImportError:
from torch.autograd import gradcheck, gradgradcheck
_USING_PARROTS = False
class TestBiasAct:
@classmethod
def setup_class(cls):
cls.input_tensor = torch.randn((1, 3), requires_grad=True)
cls.bias = torch.randn(3, requires_grad=True)
def test_bias_act_cpu(self):
out = bias_act(self.input_tensor, self.bias)
assert out.shape == (1, 3)
# test with different dim
input_tensor = torch.randn((1, 1, 3), requires_grad=True)
bias = torch.randn(3, requires_grad=True)
out = bias_act(input_tensor, bias, dim=2)
assert out.shape == (1, 1, 3)
# test with different act
out = bias_act(self.input_tensor, self.bias, act='relu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='lrelu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='tanh')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='sigmoid')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='elu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='selu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='softplus')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='swish')
assert out.shape == (1, 3)
# test with different alpha
out = bias_act(self.input_tensor, self.bias, act='lrelu', alpha=0.1)
assert out.shape == (1, 3)
# test with different gain
out1 = bias_act(self.input_tensor, self.bias, act='lrelu', gain=0.2)
out2 = bias_act(self.input_tensor, self.bias, act='lrelu', gain=0.1)
assert torch.allclose(out1, out2 * 2)
# test with different clamp
out1 = bias_act(self.input_tensor, self.bias, act='lrelu', clamp=0.5)
out2 = bias_act(self.input_tensor, self.bias, act='lrelu', clamp=0.2)
assert out1.max() <= 0.5
assert out2.max() <= 0.5
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_bias_act_cuda(self):
if _USING_PARROTS:
gradcheck(
bias_act, (self.input_tensor.cuda(), self.bias.cuda()),
delta=1e-4,
pt_atol=1e-3)
else:
gradcheck(
bias_act, (self.input_tensor.cuda(), self.bias.cuda()),
eps=1e-4,
atol=1e-3)
gradgradcheck(
bias_act, (self.input_tensor.cuda(), self.bias.cuda()),
eps=1e-4,
atol=1e-3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda())
assert out.shape == (1, 3)
# test with different dim
input_tensor = torch.randn((1, 1, 3), requires_grad=True).cuda()
bias = torch.randn(3, requires_grad=True).cuda()
out = bias_act(input_tensor, bias, dim=2)
assert out.shape == (1, 1, 3)
# test with different act
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='relu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='lrelu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='tanh')
assert out.shape == (1, 3)
out = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='sigmoid')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='elu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='selu')
assert out.shape == (1, 3)
out = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='softplus')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='swish')
assert out.shape == (1, 3)
# test with different alpha
out = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', alpha=0.1)
assert out.shape == (1, 3)
# test with different gain
out1 = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', gain=0.2)
out2 = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', gain=0.1)
assert torch.allclose(out1, out2 * 2)
# test with different clamp
out1 = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', clamp=0.5)
out2 = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', clamp=0.2)
assert out1.max() <= 0.5
assert out2.max() <= 0.5
def test_easy_dict(self):
easy_dict = EasyDict(
func=lambda x, **_: x,
def_alpha=0,
def_gain=1,
cuda_idx=1,
ref='',
has_2nd_grad=False)
_ = easy_dict.def_alpha
easy_dict.def_alpha = 1
del easy_dict.def_alpha
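# Editor's note on the gain checks above: `bias_act` applies `gain` as a final
# multiplicative factor after the activation (before any clamp), so with no
# clamp the gain=0.2 output is exactly twice the gain=0.1 output, which is
# what `torch.allclose(out1, out2 * 2)` verifies.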
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from torch.autograd import gradcheck, gradgradcheck
from mmcv.ops import conv2d, conv_transpose2d
class TestConv2d:
@classmethod
def setup_class(cls):
cls.input = torch.randn((1, 3, 32, 32), requires_grad=True)
cls.weight = nn.Parameter(torch.randn(1, 3, 3, 3))
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_conv2d_cuda(self):
x = self.input.cuda()
weight = self.weight.cuda()
res = conv2d(x, weight, None, 1, 1)
assert res.shape == (1, 1, 32, 32)
gradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)
gradgradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)
class TestConv2dTransposed:
@classmethod
def setup_class(cls):
cls.input = torch.randn((1, 3, 32, 32), requires_grad=True)
cls.weight = nn.Parameter(torch.randn(3, 1, 3, 3))
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_conv2d_transposed_cuda(self):
x = self.input.cuda()
weight = self.weight.cuda()
res = conv_transpose2d(x, weight, None, 1, 1)
assert res.shape == (1, 1, 32, 32)
gradcheck(
conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)
gradgradcheck(
conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.ops import filtered_lrelu
class TestFilteredLrelu:
@classmethod
def setup_class(cls):
cls.input_tensor = torch.randn((1, 3, 16, 16), requires_grad=True)
cls.bias = torch.randn(3, requires_grad=True)
cls.filter_up = torch.randn((2, 2))
cls.filter_down = torch.randn((2, 2))
def test_filtered_lrelu_cpu(self):
out = filtered_lrelu(self.input_tensor, bias=self.bias)
assert out.shape == (1, 3, 16, 16)
out = filtered_lrelu(
self.input_tensor,
bias=self.bias,
filter_up=self.filter_up,
filter_down=self.filter_down,
up=2,
down=2,
padding=1,
clamp=0.5)
assert out.shape == (1, 3, 16, 16)
# test with different filter_up
filter_up = torch.randn((4, 4))
out = filtered_lrelu(
self.input_tensor,
bias=self.bias,
filter_up=filter_up,
filter_down=self.filter_down,
up=2,
down=2,
padding=2,
clamp=0.5)
assert out.shape == (1, 3, 16, 16)
# test with different filter_down
filter_down = torch.randn((4, 4))
out = filtered_lrelu(
self.input_tensor,
bias=self.bias,
filter_up=self.filter_up,
filter_down=filter_down,
up=2,
down=2,
padding=2,
clamp=0.5)
assert out.shape == (1, 3, 16, 16)
        # test with different bias
input_tensor = torch.randn((1, 4, 16, 16), requires_grad=True)
bias = torch.randn(4, requires_grad=True)
out = filtered_lrelu(
input_tensor,
bias=bias,
filter_up=self.filter_up,
filter_down=self.filter_down,
up=2,
down=2,
padding=1,
clamp=0.5)
assert out.shape == (1, 4, 16, 16)
# test with different up
out = filtered_lrelu(
self.input_tensor,
bias=self.bias,
filter_up=self.filter_up,
filter_down=self.filter_down,
up=4,
down=2,
padding=1,
clamp=0.5)
assert out.shape == (1, 3, 32, 32)
# test with different down
out = filtered_lrelu(
self.input_tensor,
bias=self.bias,
filter_up=self.filter_up,
filter_down=self.filter_down,
up=2,
down=4,
padding=1,
clamp=0.5)
assert out.shape == (1, 3, 8, 8)
# test with different gain
out1 = filtered_lrelu(self.input_tensor, bias=self.bias, gain=0.2)
out2 = filtered_lrelu(self.input_tensor, bias=self.bias, gain=0.1)
assert torch.allclose(out1, 2 * out2)
# test with different slope
out = filtered_lrelu(self.input_tensor, bias=self.bias, slope=0.2)
assert out.shape == (1, 3, 16, 16)
# test with different clamp
out1 = filtered_lrelu(self.input_tensor, bias=self.bias, clamp=0.2)
out2 = filtered_lrelu(self.input_tensor, bias=self.bias, clamp=0.1)
assert out1.max() <= 0.2
assert out2.max() <= 0.1
# test with different flip_filter
        out1 = filtered_lrelu(
            self.input_tensor, bias=self.bias, flip_filter=True)
        assert out1.shape == (1, 3, 16, 16)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_filtered_lrelu_cuda(self):
out = filtered_lrelu(self.input_tensor.cuda(), bias=self.bias.cuda())
assert out.shape == (1, 3, 16, 16)
out = filtered_lrelu(
self.input_tensor.cuda(),
bias=self.bias.cuda(),
filter_up=self.filter_up.cuda(),
filter_down=self.filter_down.cuda(),
up=2,
down=2,
padding=1,
clamp=0.5)
assert out.shape == (1, 3, 16, 16)
# test with different filter_up
filter_up = torch.randn((4, 4))
out = filtered_lrelu(
self.input_tensor.cuda(),
bias=self.bias.cuda(),
filter_up=filter_up.cuda(),
filter_down=self.filter_down.cuda(),
up=2,
down=2,
padding=2,
clamp=0.5)
assert out.shape == (1, 3, 16, 16)
# test with different filter_down
filter_down = torch.randn((4, 4))
out = filtered_lrelu(
self.input_tensor.cuda(),
bias=self.bias.cuda(),
filter_up=self.filter_up.cuda(),
filter_down=filter_down.cuda(),
up=2,
down=2,
padding=2,
clamp=0.5)
assert out.shape == (1, 3, 16, 16)
        # test with different bias
input_tensor = torch.randn((1, 4, 16, 16), requires_grad=True)
bias = torch.randn(4, requires_grad=True)
out = filtered_lrelu(
input_tensor.cuda(),
bias=bias.cuda(),
filter_up=self.filter_up.cuda(),
filter_down=self.filter_down.cuda(),
up=2,
down=2,
padding=1,
clamp=0.5)
assert out.shape == (1, 4, 16, 16)
# test with different up
out = filtered_lrelu(
self.input_tensor.cuda(),
bias=self.bias.cuda(),
filter_up=self.filter_up.cuda(),
filter_down=self.filter_down.cuda(),
up=4,
down=2,
padding=1,
clamp=0.5)
assert out.shape == (1, 3, 32, 32)
# test with different down
out = filtered_lrelu(
self.input_tensor.cuda(),
bias=self.bias.cuda(),
filter_up=self.filter_up.cuda(),
filter_down=self.filter_down.cuda(),
up=2,
down=4,
padding=1,
clamp=0.5)
assert out.shape == (1, 3, 8, 8)
# test with different gain
out1 = filtered_lrelu(
self.input_tensor.cuda(), bias=self.bias.cuda(), gain=0.2)
out2 = filtered_lrelu(
self.input_tensor.cuda(), bias=self.bias.cuda(), gain=0.1)
assert torch.allclose(out1, 2 * out2)
# test with different slope
out = filtered_lrelu(
self.input_tensor.cuda(), bias=self.bias.cuda(), slope=0.2)
assert out.shape == (1, 3, 16, 16)
# test with different clamp
out1 = filtered_lrelu(
self.input_tensor.cuda(), bias=self.bias.cuda(), clamp=0.2)
out2 = filtered_lrelu(
self.input_tensor.cuda(), bias=self.bias.cuda(), clamp=0.1)
assert out1.max() <= 0.2
assert out2.max() <= 0.1
# test with different flip_filter
        out1 = filtered_lrelu(
            self.input_tensor.cuda(), bias=self.bias.cuda(), flip_filter=True)
        assert out1.shape == (1, 3, 16, 16)
......@@ -56,3 +56,29 @@ class TestUpFirDn2d:
self.input_tensor).cuda(), self.factor, 1, self.pad),
eps=1e-4,
atol=1e-3)
# test with different up
kernel = torch.randn(3, 3)
out = upfirdn2d(
self.input_tensor.cuda(), filter=kernel.cuda(), up=2, padding=1)
assert out.shape == (2, 3, 8, 8)
# test with different down
input_tensor = torch.randn(2, 3, 8, 8)
out = upfirdn2d(
input_tensor.cuda(), filter=self.kernel.cuda(), down=2, padding=1)
assert out.shape == (2, 3, 4, 4)
# test with different flip_filter
out = upfirdn2d(
self.input_tensor.cuda(),
filter=self.kernel.cuda(),
flip_filter=True)
assert out.shape == (2, 3, 1, 1)
# test with different gain
out1 = upfirdn2d(
self.input_tensor.cuda(), filter=self.kernel.cuda(), gain=0.2)
out2 = upfirdn2d(
self.input_tensor.cuda(), filter=self.kernel.cuda(), gain=0.1)
assert torch.allclose(out1, out2 * 2)
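        # Editor's note: the expected shapes above follow the upfirdn2d size
        # rule `(in * up + pad0 + pad1 - filter) // down + 1` per side, e.g.
        # the up=2, padding=1 case with a 3x3 kernel on a 4x4 input (as
        # implied by the assertions) gives (4 * 2 + 1 + 1 - 3) // 1 + 1 = 8.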