Unverified commit 869dbf1b authored by Yifei Yang, committed by GitHub

[Feature] Add Ops of StyleGAN3 (#2290)



* add bias_act

* support bias_act

* support filtered_lrelu

* support filtered_lrelu and upfirdn2d

* support conv2d_gradfix and fix filtered_lrelu

* fix lint

* fix lint

* fix c++ lint

* fix part comments

* fix lint

* rm redundant header

* fix upgrade pip

* fix as comment

* fix c++ lint

* fix ci

* fix-ut

* fix as comments

* add grad check

* remove redundant template

* Update mmcv/ops/bias_act.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* add typehint

* fix as comment:

* complete type hints

* fix lint

* add test for conv_gradfix

* add test for conv_gradfix

* fix lint

* modify licenses and ops.md

* add zh op md

* add torch version policy for conv2d_gradfix

* fix lint

* fix as comments

* rename impl

* rm redudant function and add ut

* fix as comment

* fix lint

* fix lint

* fix as comments

* fix lint

* fix ut

* fix as comment

* fix as comment

* fix as comment

---------
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 0fb07d0e
......@@ -3,6 +3,9 @@
In this file, we list the operations that are covered by licenses other than Apache 2.0. Users should be careful when adopting these operations in any commercial product.
| Operation | Files | License |
| :--------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------: |
| upfirdn2d | [mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu](https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu) | NVIDIA License |
| fused_leaky_relu | [mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu](https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu) | NVIDIA License |
| :--------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------: |
| upfirdn2d | [mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu) | NVIDIA License |
| fused_leaky_relu | [mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu) | NVIDIA License |
| bias_act | [mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu) | NVIDIA License |
| filtered_lrelu | [mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu) | NVIDIA License |
| conv2d_gradfix | [mmcv/ops/conv2d_gradfix.py](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/conv2d_gradfix.py) | NVIDIA License |
......@@ -61,3 +61,6 @@ We implement common ops used in detection, segmentation, etc.
| Voxelization | √ | √ | | | √ |
| PrRoIPool | | √ | | | |
| BezierAlign | √ | √ | | | |
| BiasAct | | √ | | | |
| FilteredLrelu | | √ | | | |
| Conv2dGradfix | | √ | | | |
......@@ -61,3 +61,6 @@ MMCV provides operators commonly used in detection, segmentation, and other tasks
| Voxelization | √ | √ | | | √ |
| PrRoIPool | | √ | | | |
| BezierAlign | √ | √ | | | |
| BiasAct | | √ | | | |
| FilteredLrelu | | √ | | | |
| Conv2dGradfix | | √ | | | |
......@@ -4,6 +4,7 @@ from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
from .bbox import bbox_overlaps
from .bezier_align import BezierAlign, bezier_align
from .bias_act import bias_act
from .border_align import BorderAlign, border_align
from .box_iou_quadri import box_iou_quadri
from .box_iou_rotated import box_iou_rotated
......@@ -11,6 +12,7 @@ from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .chamfer_distance import chamfer_distance
from .contour_expand import contour_expand
from .conv2d_gradfix import conv2d, conv_transpose2d
from .convex_iou import convex_giou, convex_iou
from .corner_pool import CornerPool
from .correlation import Correlation
......@@ -22,6 +24,7 @@ from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
from .deprecated_wrappers import Linear_deprecated as Linear
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d
from .filtered_lrelu import filtered_lrelu
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss)
from .furthest_point_sample import (furthest_point_sample,
......@@ -68,7 +71,7 @@ from .sync_bn import SyncBatchNorm
from .three_interpolate import three_interpolate
from .three_nn import three_nn
from .tin_shift import TINShift, tin_shift
from .upfirdn2d import upfirdn2d
from .upfirdn2d import filter2d, upfirdn2d, upsample2d
from .voxelize import Voxelization, voxelization
__all__ = [
......@@ -103,5 +106,6 @@ __all__ = [
'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
'PrRoIPool', 'prroi_pool', 'BezierAlign', 'bezier_align'
'PrRoIPool', 'prroi_pool', 'bias_act', 'filtered_lrelu', 'conv2d',
'conv_transpose2d', 'filter2d', 'upsample2d', 'BezierAlign', 'bezier_align'
]
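For orientation, a minimal usage sketch of the ops this commit exports from `mmcv.ops` (shapes and values are illustrative assumptions; as the implementations below show, `bias_act` and `conv2d` fall back to pure-PyTorch reference paths on CPU):

```python
import torch
from mmcv.ops import bias_act, conv2d

x = torch.randn(1, 3, 8, 8)
b = torch.randn(3)

# Fused bias + leaky ReLU along the channel dimension (dim=1).
y = bias_act(x, b, dim=1, act='lrelu')

# conv2d mirrors torch.nn.functional.conv2d and uses it as the fallback
# whenever the custom gradfix path is not taken.
w = torch.randn(4, 3, 3, 3)
out = conv2d(x, w, bias=None, stride=1, padding=1)
```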
# Modified from
# https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/bias_act.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# source: https://github.com/open-mmlab/mmediting/blob/dev-1.x/mmedit/models/editors/stylegan3/stylegan3_ops/ops/bias_act.py # noqa
"""Custom PyTorch ops for efficient bias and activation."""
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['bias_act'])
class EasyDict(dict):
"""Convenience class that behaves like a dict but allows access with the
attribute syntax."""
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name: str, value: Any) -> None:
self[name] = value
def __delattr__(self, name: str) -> None:
del self[name]
activation_funcs = {
'linear':
EasyDict(
func=lambda x, **_: x,
def_alpha=0,
def_gain=1,
cuda_idx=1,
ref='',
has_2nd_grad=False),
'relu':
EasyDict(
func=lambda x, **_: torch.nn.functional.relu(x),
def_alpha=0,
def_gain=np.sqrt(2),
cuda_idx=2,
ref='y',
has_2nd_grad=False),
'lrelu':
EasyDict(
func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
def_alpha=0.2,
def_gain=np.sqrt(2),
cuda_idx=3,
ref='y',
has_2nd_grad=False),
'tanh':
EasyDict(
func=lambda x, **_: torch.tanh(x),
def_alpha=0,
def_gain=1,
cuda_idx=4,
ref='y',
has_2nd_grad=True),
'sigmoid':
EasyDict(
func=lambda x, **_: torch.sigmoid(x),
def_alpha=0,
def_gain=1,
cuda_idx=5,
ref='y',
has_2nd_grad=True),
'elu':
EasyDict(
func=lambda x, **_: torch.nn.functional.elu(x),
def_alpha=0,
def_gain=1,
cuda_idx=6,
ref='y',
has_2nd_grad=True),
'selu':
EasyDict(
func=lambda x, **_: torch.nn.functional.selu(x),
def_alpha=0,
def_gain=1,
cuda_idx=7,
ref='y',
has_2nd_grad=True),
'softplus':
EasyDict(
func=lambda x, **_: torch.nn.functional.softplus(x),
def_alpha=0,
def_gain=1,
cuda_idx=8,
ref='y',
has_2nd_grad=True),
'swish':
EasyDict(
func=lambda x, **_: torch.sigmoid(x) * x,
def_alpha=0,
def_gain=np.sqrt(2),
cuda_idx=9,
ref='x',
has_2nd_grad=True),
}
_null_tensor = torch.empty([0])
def bias_act(input: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dim: int = 1,
act: str = 'linear',
alpha: Optional[Union[float, int]] = None,
gain: Optional[float] = None,
clamp: Optional[float] = None,
use_custom_op: bool = True):
r"""Fused bias and activation function.
Adds `bias` to the activation tensor `input`, evaluates the activation
function `act`, and scales the result by `gain`. Each of these steps is
optional.
In most cases, the fused op is considerably more efficient than performing
the same calculation using standard PyTorch ops. It supports first and
second order gradients, but not third order gradients.
Args:
input (torch.Tensor): Input activation tensor. Can be of any shape.
bias (torch.Tensor): Bias vector, or `None` to disable.
Must be a 1D tensor of the same type as `input`. The shape must
be known, and it must match the dimension of `input` corresponding
to `dim`. Defaults to None.
dim (int): The dimension in `input` corresponding to the elements of
`bias`. The value of `dim` is ignored if `bias` is not specified.
Defaults to 1.
act (str): Name of the activation function to evaluate, or `"linear"`
to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid",
"swish", etc. See `activation_funcs` for a full list. `None` is not
allowed. Defaults to `linear`.
alpha (float or int): Shape parameter for the activation
function, or `None` to use the default. Defaults to None.
gain (float): Scaling factor for the output tensor, or `None`
to use default. See `activation_funcs` for the default scaling of
each activation function. If unsure, consider specifying 1.
Defaults to None.
clamp (float): Clamp the output values to `[-clamp, +clamp]`,
or `None` to disable the clamping (default). Defaults to None.
use_custom_op (bool): Whether to use the custom CUDA op when the input
is on a CUDA device. Defaults to True.
Returns:
torch.Tensor: Tensor of the same shape and datatype as `input`.
"""
assert isinstance(input, torch.Tensor)
if use_custom_op and input.is_cuda:
return _bias_act_cuda(
dim=dim, act=act, alpha=alpha, gain=gain,
clamp=clamp).apply(input, bias)
return _bias_act_ref(
input=input,
bias=bias,
dim=dim,
act=act,
alpha=alpha,
gain=gain,
clamp=clamp)
def _bias_act_ref(input: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dim: int = 1,
act: str = 'linear',
alpha: Optional[Union[float, int]] = None,
gain: Optional[float] = None,
clamp: Optional[float] = None):
"""Slow reference implementation of `bias_act()` using standard PyTorch
ops.
Adds `bias` to the activation tensor `input`, evaluates the activation
function `act`, and scales the result by `gain`. Each of these steps is
optional.
In most cases, the fused op is considerably more efficient than performing
the same calculation using standard PyTorch ops. It supports first and
second order gradients, but not third order gradients.
Args:
input (torch.Tensor): Input activation tensor. Can be of any shape.
bias (torch.Tensor): Bias vector, or `None` to disable.
Must be a 1D tensor of the same type as `input`. The shape must
be known, and it must match the dimension of `input` corresponding
to `dim`. Defaults to None.
dim (int): The dimension in `input` corresponding to the elements of
`bias`. The value of `dim` is ignored if `bias` is not specified.
Defaults to 1.
act (str): Name of the activation function to evaluate, or `"linear"`
to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid",
"swish", etc. See `activation_funcs` for a full list. `None` is not
allowed. Defaults to `linear`.
alpha (float or int): Shape parameter for the activation
function, or `None` to use the default. Defaults to None.
gain (float): Scaling factor for the output tensor, or `None`
to use default. See `activation_funcs` for the default scaling of
each activation function. If unsure, consider specifying 1.
Defaults to None.
clamp (float): Clamp the output values to
`[-clamp, +clamp]`, or `None` to disable the clamping (default).
Defaults to None.
Returns:
torch.Tensor: Tensor of the same shape and datatype as `input`.
"""
assert isinstance(input, torch.Tensor)
assert clamp is None or clamp >= 0
spec = activation_funcs[act]
alpha = float(alpha if alpha is not None else spec.def_alpha)
gain = float(gain if gain is not None else spec.def_gain)
clamp = float(clamp if clamp is not None else -1)
# Add bias.
if bias is not None:
assert isinstance(bias, torch.Tensor) and bias.ndim == 1
assert 0 <= dim < input.ndim
assert bias.shape[0] == input.shape[dim]
input = input + bias.reshape(
[-1 if i == dim else 1 for i in range(input.ndim)])
# Evaluate activation function.
alpha = float(alpha)
output = spec.func(input, alpha=alpha)
# Scale by gain.
gain = float(gain)
if gain != 1:
output = output * gain
# Clamp.
if clamp >= 0:
# pylint: disable=invalid-unary-operand-type
output = output.clamp(-clamp, clamp)
return output
_bias_act_cuda_cache: Dict = dict()
def _bias_act_cuda(dim: int = 1,
act: str = 'linear',
alpha: Optional[Union[float, int]] = None,
gain: Optional[float] = None,
clamp: Optional[float] = None):
""""Fast CUDA implementation of `bias_act()` using custom ops.
Args:
dim (int): The dimension in `x` corresponding to the elements of `b`.
The value of `dim` is ignored if `b` is not specified.
Defaults to 1.
act (str): Name of the activation function to evaluate, or `"linear"`
to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid",
"swish", etc. See `activation_funcs` for a full list. `None` is not
allowed. Defaults to `linear`.
alpha (float | int): Shape parameter for the activation
function, or `None` to use the default. Defaults to None.
gain (float): Scaling factor for the output tensor, or `None`
to use default. See `activation_funcs` for the default scaling of
each activation function. If unsure, consider specifying 1.
Defaults to None.
clamp (float): Clamp the output values to `[-clamp, +clamp]`,
or `None` to disable the clamping (default). Defaults to None.
Returns:
torch.Tensor: Tensor of the same shape and datatype as `x`.
"""
# Parse arguments.
assert clamp is None or clamp >= 0
spec = activation_funcs[act]
alpha = float(alpha if alpha is not None else spec.def_alpha)
gain = float(gain if gain is not None else spec.def_gain)
clamp = float(clamp if clamp is not None else -1)
# Lookup from cache.
key = (dim, act, alpha, gain, clamp)
if key in _bias_act_cuda_cache:
return _bias_act_cuda_cache[key]
# Forward op.
class BiasActCuda(torch.autograd.Function):
@staticmethod
def forward(ctx, x, b): # pylint: disable=arguments-differ
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(
1) == 1 else torch.contiguous_format
x = x.contiguous(memory_format=ctx.memory_format)
b = b.contiguous() if b is not None else _null_tensor.to(x.device)
y = x
if act != 'linear' or gain != 1 or clamp >= 0 or (
b is not _null_tensor.to(x.device)):
y = ext_module.bias_act(x, b, _null_tensor.to(x.device),
_null_tensor.to(x.device),
_null_tensor.to(x.device), 0, dim,
spec.cuda_idx, alpha, gain, clamp)
ctx.save_for_backward(
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor.to(
x.device), b if 'x' in spec.ref or spec.has_2nd_grad else
_null_tensor.to(x.device),
y if 'y' in spec.ref else _null_tensor.to(x.device))
return y
@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
dy = dy.contiguous(memory_format=ctx.memory_format)
x, b, y = ctx.saved_tensors
dx = None
db = None
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
dx = dy
if act != 'linear' or gain != 1 or clamp >= 0:
dx = BiasActCudaGrad.apply(dy, x, b, y)
if ctx.needs_input_grad[1]:
db = dx.sum([i for i in range(dx.ndim) if i != dim])
return dx, db
# Backward op.
class BiasActCudaGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
ctx.memory_format = torch.channels_last if dy.ndim > 2 and (
dy.stride(1) == 1) else torch.contiguous_format
dx = ext_module.bias_act(dy, b, x, y, _null_tensor.to(x.device), 1,
dim, spec.cuda_idx, alpha, gain, clamp)
ctx.save_for_backward(
dy if spec.has_2nd_grad else _null_tensor.to(x.device), x, b,
y)
return dx
@staticmethod
def backward(ctx, d_dx): # pylint: disable=arguments-differ
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
dy, x, b, y = ctx.saved_tensors
d_dy = None
d_x = None
d_b = None
d_y = None
if ctx.needs_input_grad[0]:
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
if spec.has_2nd_grad and (ctx.needs_input_grad[1]
or ctx.needs_input_grad[2]):
d_x = ext_module.bias_act(d_dx, b, x, y, dy, 2, dim,
spec.cuda_idx, alpha, gain, clamp)
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
return d_dy, d_x, d_b, d_y
# Add to cache.
_bias_act_cuda_cache[key] = BiasActCuda
return BiasActCuda
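To make the fused semantics concrete, here is a small sketch (run on CPU, so it exercises `_bias_act_ref`) checking that `bias_act` matches the explicit bias-add, activation, and gain steps it fuses; the shapes and tolerance are assumptions, not values from this commit:

```python
import torch
import torch.nn.functional as F
from mmcv.ops import bias_act

x = torch.randn(2, 4, 8, 8)
b = torch.randn(4)

# Fused call with gain pinned to 1.0 (the 'lrelu' default would be sqrt(2)).
fused = bias_act(x, b, dim=1, act='lrelu', gain=1.0)

# Explicit equivalent: broadcast the bias along dim=1, then apply leaky
# ReLU with the default alpha of 0.2 for 'lrelu'.
manual = F.leaky_relu(x + b.reshape(1, -1, 1, 1), negative_slope=0.2)

assert torch.allclose(fused, manual, atol=1e-6)
```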
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# source: https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/conv2d_gradfix.py # noqa
"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""
import contextlib
import warnings
from typing import Dict, Optional, Tuple, Union
import torch
enabled = True
weight_gradients_disabled = False
@contextlib.contextmanager
def no_weight_gradients(disable=True):
global weight_gradients_disabled
old = weight_gradients_disabled
if disable:
weight_gradients_disabled = True
yield
weight_gradients_disabled = old
def conv2d(input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...]] = 0,
dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1):
flag = True
if torch.__version__ >= '1.10.0':
warnings.warn(
'Since aten::cudnn_convolution_backward_weight is not supported '
f'in torch=={torch.__version__}, falling back to '
'`torch.nn.functional.conv2d`')
flag = False
if _should_use_custom_op(input) and flag:
return _conv2d_gradfix(
transpose=False,
weight_shape=weight.shape,
stride=stride,
padding=padding,
output_padding=0,
dilation=dilation,
groups=groups).apply(input, weight, bias)
return torch.nn.functional.conv2d(
input=input,
weight=weight,
bias=bias,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups)
def conv_transpose2d(input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...]] = 0,
output_padding: Union[int, Tuple[int, ...]] = 0,
groups: int = 1,
dilation: Union[int, Tuple[int, ...]] = 1):
if _should_use_custom_op(input):
return _conv2d_gradfix(
transpose=True,
weight_shape=weight.shape,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation).apply(input, weight, bias)
return torch.nn.functional.conv_transpose2d(
input=input,
weight=weight,
bias=bias,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
def _should_use_custom_op(input):
assert isinstance(input, torch.Tensor)
if (not enabled) or (not torch.backends.cudnn.enabled):
return False
if input.device.type != 'cuda':
return False
return True
def _to_tuple(x, ndim):
xs = tuple(x) if isinstance(x, (tuple, list)) else (x, ) * ndim
assert len(xs) == ndim
assert all(isinstance(x, int) for x in xs)
return xs
_conv2d_gradfix_cache: Dict = dict()
_null_tensor = torch.empty([0])
def _conv2d_gradfix(
transpose: bool,
weight_shape: Tuple[int, ...],
stride: Union[int, Tuple[int, ...]],
padding: Union[int, Tuple[int, ...]],
output_padding: Union[int, Tuple[int, ...]],
dilation: Union[int, Tuple[int, ...]],
groups: int,
):
# Parse arguments.
ndim = 2
weight_shape = tuple(weight_shape)
stride = _to_tuple(stride, ndim)
padding = _to_tuple(padding, ndim)
output_padding = _to_tuple(output_padding, ndim)
dilation = _to_tuple(dilation, ndim)
# Lookup from cache.
key = (transpose, weight_shape, stride, padding, output_padding, dilation,
groups)
if key in _conv2d_gradfix_cache:
return _conv2d_gradfix_cache[key]
# Validate arguments.
assert groups >= 1
assert len(weight_shape) == ndim + 2
assert all(stride[i] >= 1 for i in range(ndim)) # type: ignore
assert all(padding[i] >= 0 for i in range(ndim)) # type: ignore
assert all(dilation[i] >= 0 for i in range(ndim)) # type: ignore
if not transpose:
assert all(output_padding[i] == 0 for i in range(ndim)) # type: ignore
else: # transpose
for i in range(ndim):
assert 0 <= output_padding[i] < max( # type: ignore
stride[i], # type: ignore
dilation[i]) # type: ignore
# Helpers.
common_kwargs = dict(
stride=stride, padding=padding, dilation=dilation, groups=groups)
def calc_output_padding(input_shape, output_shape):
if transpose:
return [0, 0]
return [
input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] -
(1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1)
for i in range(ndim)
]
# Forward & backward.
class Conv2d(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias):
assert weight.shape == weight_shape
ctx.save_for_backward(
input if weight.requires_grad else _null_tensor,
weight if input.requires_grad else _null_tensor,
)
ctx.input_shape = input.shape
# Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
if weight_shape[2:] == stride == dilation == (
1, 1) and padding == (
0, 0) and torch.cuda.get_device_capability(
input.device) < (8, 0):
a = weight.reshape(groups, weight_shape[0] // groups,
weight_shape[1])
b = input.reshape(input.shape[0], groups,
input.shape[1] // groups, -1)
c = (a.transpose(1, 2) if transpose else a) @ b.permute(
1, 2, 0, 3).flatten(2)
c = c.reshape(-1, input.shape[0],
*input.shape[2:]).transpose(0, 1)
c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(
2).unsqueeze(3)
return c.contiguous(
memory_format=(torch.channels_last if input.stride(1) ==
1 else torch.contiguous_format))
# General case => cuDNN.
if transpose:
return torch.nn.functional.conv_transpose2d(
input=input,
weight=weight,
bias=bias,
output_padding=output_padding,
**common_kwargs)
return torch.nn.functional.conv2d(
input=input, weight=weight, bias=bias, **common_kwargs)
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
input_shape = ctx.input_shape
grad_input = None
grad_weight = None
grad_bias = None
if ctx.needs_input_grad[0]:
p = calc_output_padding(
input_shape=input_shape, output_shape=grad_output.shape)
op = _conv2d_gradfix(
transpose=(not transpose),
weight_shape=weight_shape,
output_padding=p,
**common_kwargs)
grad_input = op.apply(grad_output, weight, None)
assert grad_input.shape == input_shape
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
grad_weight = Conv2dGradWeight.apply(grad_output, input)
assert grad_weight.shape == weight_shape
if ctx.needs_input_grad[2]:
grad_bias = grad_output.sum([0, 2, 3])
return grad_input, grad_weight, grad_bias
# Gradient with respect to the weights.
class Conv2dGradWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input):
ctx.save_for_backward(
grad_output if input.requires_grad else _null_tensor,
input if grad_output.requires_grad else _null_tensor,
)
ctx.grad_output_shape = grad_output.shape
ctx.input_shape = input.shape
# Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
if weight_shape[2:] == stride == dilation == (
1, 1) and padding == (0, 0):
a = grad_output.reshape(grad_output.shape[0], groups,
grad_output.shape[1] // groups,
-1).permute(1, 2, 0, 3).flatten(2)
b = input.reshape(input.shape[0], groups,
input.shape[1] // groups,
-1).permute(1, 2, 0, 3).flatten(2)
c = (b @ a.transpose(1, 2) if transpose else
a @ b.transpose(1, 2)).reshape(weight_shape)
return c.contiguous(
memory_format=(torch.channels_last if input.stride(1) ==
1 else torch.contiguous_format))
# General case => cuDNN.
name = ('aten::cudnn_convolution_transpose_backward_weight' if
transpose else 'aten::cudnn_convolution_backward_weight')
flags = [
torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic,
torch.backends.cudnn.allow_tf32
]
return torch._C._jit_get_operation(name)(weight_shape, grad_output,
input, padding, stride,
dilation, groups, *flags)
@staticmethod
def backward(ctx, grad2_grad_weight):
grad_output, input = ctx.saved_tensors
grad_output_shape = ctx.grad_output_shape
input_shape = ctx.input_shape
grad2_grad_output = None
grad2_input = None
if ctx.needs_input_grad[0]:
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight,
None)
assert grad2_grad_output.shape == grad_output_shape
if ctx.needs_input_grad[1]:
p = calc_output_padding(
input_shape=input_shape, output_shape=grad_output_shape)
op = _conv2d_gradfix(
transpose=(not transpose),
weight_shape=weight_shape,
output_padding=p,
**common_kwargs)
grad2_input = op.apply(grad_output, grad2_grad_weight, None)
assert grad2_input.shape == input_shape
return grad2_grad_output, grad2_input
_conv2d_gradfix_cache[key] = Conv2d
return Conv2d
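A short hedged sketch of what the gradfix wrappers are for: per the module docstring, they are drop-in replacements for the functional convolutions that support high-order gradients (e.g. R1-style penalties on the input gradient), and `no_weight_gradients()` can skip the weight gradient when it is not needed. The shapes below are assumptions.

```python
import torch
from mmcv.ops import conv2d
from mmcv.ops.conv2d_gradfix import no_weight_gradients

x = torch.randn(1, 3, 16, 16, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)

y = conv2d(x, w, padding=1)

# Double backward: a penalty on the input gradient, differentiated again.
(gx,) = torch.autograd.grad(y.sum(), x, create_graph=True)
penalty = gx.square().sum()
penalty.backward()  # reaches w.grad through the second-order graph

# Skip weight gradients (effective when the custom CUDA path is taken),
# as done for lazy regularization in StyleGAN training.
with no_weight_gradients():
    y2 = conv2d(x, w, padding=1)
```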
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
torch::Tensor bias_act_op_impl(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &xref,
const torch::Tensor &yref,
const torch::Tensor &dy, int grad, int dim,
int act, float alpha, float gain, float clamp) {
return DISPATCH_DEVICE_IMPL(bias_act_op_impl, input, bias, xref, yref, dy,
grad, dim, act, alpha, gain, clamp);
}
torch::Tensor bias_act(const torch::Tensor &input, const torch::Tensor &bias,
const torch::Tensor &xref, const torch::Tensor &yref,
const torch::Tensor &dy, int grad, int dim, int act,
float alpha, float gain, float clamp) {
return bias_act_op_impl(input, bias, xref, yref, dy, grad, dim, act, alpha,
gain, clamp);
}
// Modified from
// https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/bias_act.cpp
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "pytorch_cuda_helper.hpp"
struct bias_act_kernel_params {
const void* x; // [sizeX]
const void* b; // [sizeB] or NULL
const void* xref; // [sizeX] or NULL
const void* yref; // [sizeX] or NULL
const void* dy; // [sizeX] or NULL
void* y; // [sizeX]
int grad;
int act;
float alpha;
float gain;
float clamp;
int sizeX;
int sizeB;
int stepB;
int loopX;
};
// CUDA kernel selection.
template <class T>
void* choose_bias_act_kernel(const bias_act_kernel_params& p);
//------------------------------------------------------------------------
// Helpers.
template <class T>
struct InternalType;
template <>
struct InternalType<double> {
typedef double scalar_t;
};
template <>
struct InternalType<float> {
typedef float scalar_t;
};
template <>
struct InternalType<c10::Half> {
typedef float scalar_t;
};
//------------------------------------------------------------------------
// CUDA kernel.
template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p) {
typedef typename InternalType<T>::scalar_t scalar_t;
int G = p.grad;
scalar_t alpha = (scalar_t)p.alpha;
scalar_t gain = (scalar_t)p.gain;
scalar_t clamp = (scalar_t)p.clamp;
scalar_t one = (scalar_t)1;
scalar_t two = (scalar_t)2;
scalar_t expRange = (scalar_t)80;
scalar_t halfExpRange = (scalar_t)40;
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
// Loop over elements.
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX;
loopIdx++, xi += blockDim.x) {
// Load.
scalar_t x = (scalar_t)((const T*)p.x)[xi];
scalar_t b =
(p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
scalar_t yy = (gain != 0) ? yref / gain : 0;
scalar_t y = 0;
// Apply bias.
((G == 0) ? x : xref) += b;
// linear
if (A == 1) {
if (G == 0) y = x;
if (G == 1) y = x;
}
// relu
if (A == 2) {
if (G == 0) y = (x > 0) ? x : 0;
if (G == 1) y = (yy > 0) ? x : 0;
}
// lrelu
if (A == 3) {
if (G == 0) y = (x > 0) ? x : x * alpha;
if (G == 1) y = (yy > 0) ? x : x * alpha;
}
// tanh
if (A == 4) {
if (G == 0) {
scalar_t c = exp(x);
scalar_t d = one / c;
y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d);
}
if (G == 1) y = x * (one - yy * yy);
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
}
// sigmoid
if (A == 5) {
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
if (G == 1) y = x * yy * (one - yy);
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
}
// elu
if (A == 6) {
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
}
// selu
if (A == 7) {
if (G == 0)
y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
if (G == 1)
y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
}
// softplus
if (A == 8) {
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
if (G == 1) y = x * (one - exp(-yy));
if (G == 2) {
scalar_t c = exp(-yy);
y = x * c * (one - c);
}
}
// swish
if (A == 9) {
if (G == 0)
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
else {
scalar_t c = exp(xref);
scalar_t d = c + one;
if (G == 1)
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
else
y = (xref > halfExpRange)
? 0
: x * c * (xref * (two - d) + two * d) / (d * d * d);
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
}
}
// Apply gain.
y *= gain * dy;
// Clamp.
if (clamp >= 0) {
if (G == 0)
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
else
y = (yref > -clamp & yref < clamp) ? y : 0;
}
// Store.
((T*)p.y)[xi] = (T)y;
}
}
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T>
void* choose_bias_act_kernel(const bias_act_kernel_params& p) {
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
return NULL;
}
//------------------------------------------------------------------------
static bool has_same_layout(torch::Tensor x, torch::Tensor y) {
if (x.dim() != y.dim()) return false;
for (int64_t i = 0; i < x.dim(); i++) {
if (x.size(i) != y.size(i)) return false;
if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) return false;
}
return true;
}
//------------------------------------------------------------------------
torch::Tensor bias_act_op(const torch::Tensor& x, const torch::Tensor& b,
const torch::Tensor& xref, const torch::Tensor& yref,
const torch::Tensor& dy, int grad, int dim, int act,
float alpha, float gain, float clamp) {
// Validate arguments.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
TORCH_CHECK(
b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()),
"b must have the same dtype and device as x");
TORCH_CHECK(xref.numel() == 0 ||
(xref.sizes() == x.sizes() && xref.dtype() == x.dtype() &&
xref.device() == x.device()),
"xref must have the same shape, dtype, and device as x");
TORCH_CHECK(yref.numel() == 0 ||
(yref.sizes() == x.sizes() && yref.dtype() == x.dtype() &&
yref.device() == x.device()),
"yref must have the same shape, dtype, and device as x");
TORCH_CHECK(
dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() &&
dy.device() == x.device()),
"dy must have the same dtype and device as x");
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()),
"dim is out of bounds");
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim),
"b has wrong number of elements");
TORCH_CHECK(grad >= 0, "grad must be non-negative");
// Validate layout.
TORCH_CHECK(x.is_non_overlapping_and_dense(),
"x must be non-overlapping and dense");
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x),
"xref must have the same layout as x");
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x),
"yref must have the same layout as x");
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x),
"dy must have the same layout as x");
// Create output tensor.
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
torch::Tensor y = torch::empty_like(x);
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
// Initialize CUDA kernel parameters.
bias_act_kernel_params p;
p.x = x.data_ptr();
p.b = (b.numel()) ? b.data_ptr() : NULL;
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
p.y = y.data_ptr();
p.grad = grad;
p.act = act;
p.alpha = alpha;
p.gain = gain;
p.clamp = clamp;
p.sizeX = (int)x.numel();
p.sizeB = (int)b.numel();
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
// Choose CUDA kernel.
void* kernel;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&] {
kernel = choose_bias_act_kernel<scalar_t>(p);
});
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
// Launch CUDA kernel.
p.loopX = 4;
int blockSize = 4 * 32;
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
void* args[] = {&p};
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0,
at::cuda::getCurrentCUDAStream()));
return y;
}
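As an aside, the launch geometry above (`loopX = 4`, `blockSize = 128`, `gridSize = (sizeX - 1) / (loopX * blockSize) + 1`) gives each thread up to `loopX` strided elements. A small Python sketch of the same indexing (an illustration, not part of the commit):

```python
# Mirror of the indexing in bias_act_kernel: every index in [0, sizeX) is
# visited exactly once across blocks, threads, and loopX steps.
def covered_indices(size_x, block_size=4 * 32, loop_x=4):
    grid_size = (size_x - 1) // (loop_x * block_size) + 1
    seen = []
    for block in range(grid_size):
        for thread in range(block_size):
            xi = block * loop_x * block_size + thread
            for _ in range(loop_x):
                if xi >= size_x:
                    break
                seen.append(xi)
                xi += block_size
    return seen

assert sorted(covered_indices(1000)) == list(range(1000))
```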
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op_impl(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns) {
return DISPATCH_DEVICE_IMPL(filtered_lrelu_op_impl, x, fu, fd, b, si, up,
down, px0, px1, py0, py1, sx, sy, gain, slope,
clamp, flip_filters, writeSigns);
}
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns) {
return filtered_lrelu_op_impl(x, fu, fd, b, si, up, down, px0, px1, py0, py1,
sx, sy, gain, slope, clamp, flip_filters,
writeSigns);
}
torch::Tensor filtered_lrelu_act_op_impl(torch::Tensor x, torch::Tensor si,
int sx, int sy, float gain,
float slope, float clamp,
bool writeSigns) {
return DISPATCH_DEVICE_IMPL(filtered_lrelu_act_op_impl, x, si, sx, sy, gain,
slope, clamp, writeSigns);
}
torch::Tensor filtered_lrelu_act_(torch::Tensor x, torch::Tensor si, int sx,
int sy, float gain, float slope, float clamp,
bool writeSigns) {
return filtered_lrelu_act_op_impl(x, si, sx, sy, gain, slope, clamp,
writeSigns);
}
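Conceptually, `filtered_lrelu` fuses an upsampling FIR filter, a bias + leaky ReLU (with optional gain and clamp), and a downsampling FIR filter into one kernel. A hedged Python sketch of that composition using the other ops from this commit (a simplified illustration modeled on the StyleGAN3 reference path, not the exact wrapper signature):

```python
import torch
from mmcv.ops import bias_act, upfirdn2d

def filtered_lrelu_sketch(x, fu, fd, b, up=2, down=2, padding=0,
                          gain=2 ** 0.5, slope=0.2, clamp=None):
    x = bias_act(x, b)                                    # add bias
    x = upfirdn2d(x, filter=fu, up=up, padding=padding,
                  gain=up ** 2)                           # upsample with fu
    x = bias_act(x, act='lrelu', alpha=slope,
                 gain=gain, clamp=clamp)                  # leaky ReLU
    x = upfirdn2d(x, filter=fd, down=down)                # downsample with fd
    return x
```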
......@@ -312,9 +312,9 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const Tensor labels,
const float iou_threshold, const int multi_label);
Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
int pad_y1);
Tensor upfirdn2d(torch::Tensor input, torch::Tensor filter, int upx, int upy,
int downx, int downy, int padx0, int padx1, int pady0,
int pady1, bool flip, float gain);
Tensor fused_bias_leakyrelu(const Tensor &input, const Tensor &bias,
const Tensor &refer, int act, int grad, float alpha,
......@@ -439,6 +439,20 @@ void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);
Tensor bias_act(const Tensor &input, const Tensor &bias, const Tensor &xref,
const Tensor &yref, const Tensor &dy, int grad, int dim,
int act, float alpha, float gain, float clamp);
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns);
torch::Tensor filtered_lrelu_act_(torch::Tensor x, torch::Tensor si, int sx,
int sy, float gain, float slope, float clamp,
bool writeSigns);
void box_iou_quadri(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
......@@ -458,9 +472,9 @@ void bezier_align_backward(Tensor grad_output, Tensor rois, Tensor grad_input,
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
py::arg("down_y"), py::arg("pad_x0"), py::arg("pad_x1"),
py::arg("pad_y0"), py::arg("pad_y1"));
py::arg("filter"), py::arg("upx"), py::arg("upy"), py::arg("downx"),
py::arg("downy"), py::arg("padx0"), py::arg("padx1"), py::arg("pady0"),
py::arg("pady1"), py::arg("flip"), py::arg("gain"));
m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu,
"fused_bias_leakyrelu (CUDA)", py::arg("input"), py::arg("bias"),
py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"),
......@@ -902,6 +916,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("input"), py::arg("rois"), py::arg("grad_rois"),
py::arg("pooled_height"), py::arg("pooled_width"),
py::arg("spatial_scale"));
m.def("bias_act", &bias_act, "bias_act (CUDA)", py::arg("input"),
py::arg("bias"), py::arg("xref"), py::arg("yref"), py::arg("dy"),
py::arg("grad"), py::arg("dim"), py::arg("act"), py::arg("alpha"),
py::arg("gain"), py::arg("clamp"));
m.def("filtered_lrelu", &filtered_lrelu, "filtered_lrelu (CUDA)",
py::arg("x"), py::arg("fu"), py::arg("fd"), py::arg("b"), py::arg("si"),
py::arg("up"), py::arg("down"), py::arg("px0"), py::arg("px1"),
py::arg("py0"), py::arg("py1"), py::arg("sx"), py::arg("sy"),
py::arg("gain"), py::arg("slope"), py::arg("clamp"),
py::arg("flip_filters"), py::arg("writeSigns"));
m.def("filtered_lrelu_act_", &filtered_lrelu_act_,
"filtered_lrelu_act_ (CUDA)", py::arg("x"), py::arg("si"),
py::arg("sx"), py::arg("sy"), py::arg("gain"), py::arg("slope"),
py::arg("clamp"), py::arg("writeSigns"));
m.def("box_iou_quadri", &box_iou_quadri, "IoU for quadrilateral boxes",
py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"),
py::arg("mode_flag"), py::arg("aligned"));
......@@ -102,17 +102,17 @@ THE POSSIBILITY OF SUCH DAMAGES.
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
torch::Tensor upfirdn2d_op_impl(const torch::Tensor& input,
const torch::Tensor& kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1) {
return DISPATCH_DEVICE_IMPL(upfirdn2d_op_impl, input, kernel, up_x, up_y,
down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor filter,
int upx, int upy, int downx, int downy,
int padx0, int padx1, int pady0, int pady1,
bool flip, float gain) {
return DISPATCH_DEVICE_IMPL(upfirdn2d_op_impl, input, filter, upx, upy, downx,
downy, padx0, padx1, pady0, pady1, flip, gain);
}
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y, int pad_x0,
int pad_x1, int pad_y0, int pad_y1) {
return upfirdn2d_op_impl(input, kernel, up_x, up_y, down_x, down_y, pad_x0,
pad_x1, pad_y0, pad_y1);
torch::Tensor upfirdn2d(torch::Tensor input, torch::Tensor filter, int upx,
int upy, int downx, int downy, int padx0, int padx1,
int pady0, int pady1, bool flip, float gain) {
return upfirdn2d_op_impl(input, filter, upx, upy, downx, downy, padx0, padx1,
pady0, pady1, flip, gain);
}
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.ops import bias_act
from mmcv.ops.bias_act import EasyDict
_USING_PARROTS = True
try:
from parrots.autograd import gradcheck
except ImportError:
from torch.autograd import gradcheck, gradgradcheck
_USING_PARROTS = False
class TestBiasAct:
@classmethod
def setup_class(cls):
cls.input_tensor = torch.randn((1, 3), requires_grad=True)
cls.bias = torch.randn(3, requires_grad=True)
def test_bias_act_cpu(self):
out = bias_act(self.input_tensor, self.bias)
assert out.shape == (1, 3)
# test with different dim
input_tensor = torch.randn((1, 1, 3), requires_grad=True)
bias = torch.randn(3, requires_grad=True)
out = bias_act(input_tensor, bias, dim=2)
assert out.shape == (1, 1, 3)
# test with different act
out = bias_act(self.input_tensor, self.bias, act='relu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='lrelu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='tanh')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='sigmoid')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='elu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='selu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='softplus')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor, self.bias, act='swish')
assert out.shape == (1, 3)
# test with different alpha
out = bias_act(self.input_tensor, self.bias, act='lrelu', alpha=0.1)
assert out.shape == (1, 3)
# test with different gain
out1 = bias_act(self.input_tensor, self.bias, act='lrelu', gain=0.2)
out2 = bias_act(self.input_tensor, self.bias, act='lrelu', gain=0.1)
assert torch.allclose(out1, out2 * 2)
# test with different clamp
out1 = bias_act(self.input_tensor, self.bias, act='lrelu', clamp=0.5)
out2 = bias_act(self.input_tensor, self.bias, act='lrelu', clamp=0.2)
assert out1.max() <= 0.5
assert out2.max() <= 0.2
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_bias_act_cuda(self):
if _USING_PARROTS:
gradcheck(
bias_act, (self.input_tensor.cuda(), self.bias.cuda()),
delta=1e-4,
pt_atol=1e-3)
else:
gradcheck(
bias_act, (self.input_tensor.cuda(), self.bias.cuda()),
eps=1e-4,
atol=1e-3)
gradgradcheck(
bias_act, (self.input_tensor.cuda(), self.bias.cuda()),
eps=1e-4,
atol=1e-3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda())
assert out.shape == (1, 3)
# test with different dim
input_tensor = torch.randn((1, 1, 3), requires_grad=True).cuda()
bias = torch.randn(3, requires_grad=True).cuda()
out = bias_act(input_tensor, bias, dim=2)
assert out.shape == (1, 1, 3)
# test with different act
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='relu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='lrelu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='tanh')
assert out.shape == (1, 3)
out = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='sigmoid')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='elu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='selu')
assert out.shape == (1, 3)
out = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='softplus')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='swish')
assert out.shape == (1, 3)
# test with different alpha
out = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', alpha=0.1)
assert out.shape == (1, 3)
# test with different gain
out1 = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', gain=0.2)
out2 = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', gain=0.1)
assert torch.allclose(out1, out2 * 2)
# test with different clamp
out1 = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', clamp=0.5)
out2 = bias_act(
self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', clamp=0.2)
assert out1.max() <= 0.5
assert out2.max() <= 0.2
def test_easy_dict(self):
easy_dict = EasyDict(
func=lambda x, **_: x,
def_alpha=0,
def_gain=1,
cuda_idx=1,
ref='',
has_2nd_grad=False)
_ = easy_dict.def_alpha
easy_dict.def_alpha = 1
del easy_dict.def_alpha
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from torch.autograd import gradcheck, gradgradcheck
from mmcv.ops import conv2d, conv_transpose2d
class TestConv2d:
@classmethod
def setup_class(cls):
cls.input = torch.randn((1, 3, 32, 32), requires_grad=True)
cls.weight = nn.Parameter(torch.randn(1, 3, 3, 3))
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_conv2d_cuda(self):
x = self.input.cuda()
weight = self.weight.cuda()
res = conv2d(x, weight, None, 1, 1)
assert res.shape == (1, 1, 32, 32)
gradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)
gradgradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)
class TestConv2dTransposed:
@classmethod
def setup_class(cls):
cls.input = torch.randn((1, 3, 32, 32), requires_grad=True)
cls.weight = nn.Parameter(torch.randn(3, 1, 3, 3))
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_conv2d_transposed_cuda(self):
x = self.input.cuda()
weight = self.weight.cuda()
res = conv_transpose2d(x, weight, None, 1, 1)
assert res.shape == (1, 1, 32, 32)
gradcheck(
conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)
gradgradcheck(
conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)
......@@ -56,3 +56,29 @@ class TestUpFirDn2d:
self.input_tensor).cuda(), self.factor, 1, self.pad),
eps=1e-4,
atol=1e-3)
# test with different up
kernel = torch.randn(3, 3)
out = upfirdn2d(
self.input_tensor.cuda(), filter=kernel.cuda(), up=2, padding=1)
assert out.shape == (2, 3, 8, 8)
# test with different down
input_tensor = torch.randn(2, 3, 8, 8)
out = upfirdn2d(
input_tensor.cuda(), filter=self.kernel.cuda(), down=2, padding=1)
assert out.shape == (2, 3, 4, 4)
# test with different flip_filter
out = upfirdn2d(
self.input_tensor.cuda(),
filter=self.kernel.cuda(),
flip_filter=True)
assert out.shape == (2, 3, 1, 1)
# test with different gain
out1 = upfirdn2d(
self.input_tensor.cuda(), filter=self.kernel.cuda(), gain=0.2)
out2 = upfirdn2d(
self.input_tensor.cuda(), filter=self.kernel.cuda(), gain=0.1)
assert torch.allclose(out1, out2 * 2)
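For reference, a hedged sketch of the updated `upfirdn2d` interface these tests exercise (keyword names taken from the test calls; a CUDA device is assumed, matching the tests):

```python
import torch
from mmcv.ops import upfirdn2d

if torch.cuda.is_available():
    x = torch.randn(2, 3, 4, 4).cuda()
    f = torch.randn(3, 3).cuda()

    # Upsample by 2 with padding 1: (2, 3, 4, 4) -> (2, 3, 8, 8).
    up = upfirdn2d(x, filter=f, up=2, padding=1)

    # Downsample by 2 with padding 1: (2, 3, 8, 8) -> (2, 3, 4, 4).
    down = upfirdn2d(up, filter=f, down=2, padding=1)

    # Flip the filter and rescale the output.
    y = upfirdn2d(x, filter=f, flip_filter=True, gain=2.0)
```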