Commit 6f3c5f1c authored by limm

support v1.4.0

parent 6f674c7e
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Union
import cv2
import numpy as np
def imconvert(img: np.ndarray, src: str, dst: str) -> np.ndarray:
def imconvert(img, src, dst):
"""Convert an image from the src colorspace to dst colorspace.
Args:
......@@ -21,7 +19,7 @@ def imconvert(img: np.ndarray, src: str, dst: str) -> np.ndarray:
return out_img
def bgr2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
def bgr2gray(img, keepdim=False):
"""Convert a BGR image to grayscale image.
Args:
......@@ -38,7 +36,7 @@ def bgr2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
return out_img
def rgb2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
def rgb2gray(img, keepdim=False):
"""Convert a RGB image to grayscale image.
Args:
......@@ -55,7 +53,7 @@ def rgb2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
return out_img
def gray2bgr(img: np.ndarray) -> np.ndarray:
def gray2bgr(img):
"""Convert a grayscale image to BGR image.
Args:
......@@ -69,7 +67,7 @@ def gray2bgr(img: np.ndarray) -> np.ndarray:
return out_img
def gray2rgb(img: np.ndarray) -> np.ndarray:
def gray2rgb(img):
"""Convert a grayscale image to RGB image.
Args:
......@@ -83,7 +81,7 @@ def gray2rgb(img: np.ndarray) -> np.ndarray:
return out_img
def _convert_input_type_range(img: np.ndarray) -> np.ndarray:
def _convert_input_type_range(img):
"""Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1].
......@@ -111,8 +109,7 @@ def _convert_input_type_range(img: np.ndarray) -> np.ndarray:
return img
def _convert_output_type_range(
img: np.ndarray, dst_type: Union[np.uint8, np.float32]) -> np.ndarray:
def _convert_output_type_range(img, dst_type):
"""Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8,
......@@ -143,7 +140,7 @@ def _convert_output_type_range(
return img.astype(dst_type)
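# A minimal sketch of how the two private helpers above pair up, assuming the
# behaviour described in their (partly elided) docstrings: the input helper
# maps any supported dtype to float32 in [0, 1], and the output helper maps a
# result lying in [0, 255] back to the requested dtype. Values illustrative.
import numpy as np

u8 = np.array([[0, 128, 255]], dtype=np.uint8)
f32 = _convert_input_type_range(u8)                      # float32, ~[0., 0.502, 1.]
back = _convert_output_type_range(f32 * 255, np.uint8)   # uint8 again: [0, 128, 255]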
def rgb2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
def rgb2ycbcr(img, y_only=False):
"""Convert a RGB image to YCbCr image.
This function produces the same results as Matlab's `rgb2ycbcr` function.
......@@ -163,7 +160,7 @@ def rgb2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
......@@ -177,7 +174,7 @@ def rgb2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
return out_img
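# A minimal usage sketch for the converter above, assuming it is exposed as
# ``mmcv.rgb2ycbcr``; the random image is purely illustrative.
import numpy as np

import mmcv

img = np.random.randint(0, 256, size=(8, 8, 3), dtype=np.uint8)
ycbcr = mmcv.rgb2ycbcr(img)            # (8, 8, 3), same dtype as the input
y = mmcv.rgb2ycbcr(img, y_only=True)   # (8, 8), luma channel only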
def bgr2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
......@@ -197,7 +194,7 @@ def bgr2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
......@@ -211,7 +208,7 @@ def bgr2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
return out_img
def ycbcr2rgb(img: np.ndarray) -> np.ndarray:
def ycbcr2rgb(img):
"""Convert a YCbCr image to RGB image.
This function produces the same results as Matlab's ycbcr2rgb function.
......@@ -230,7 +227,7 @@ def ycbcr2rgb(img: np.ndarray) -> np.ndarray:
Returns:
ndarray: The converted RGB image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
......@@ -243,7 +240,7 @@ def ycbcr2rgb(img: np.ndarray) -> np.ndarray:
return out_img
def ycbcr2bgr(img: np.ndarray) -> np.ndarray:
def ycbcr2bgr(img):
"""Convert a YCbCr image to BGR image.
The bgr version of ycbcr2rgb.
......@@ -262,7 +259,7 @@ def ycbcr2bgr(img: np.ndarray) -> np.ndarray:
Returns:
ndarray: The converted BGR image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
......@@ -275,11 +272,11 @@ def ycbcr2bgr(img: np.ndarray) -> np.ndarray:
return out_img
def convert_color_factory(src: str, dst: str) -> Callable:
def convert_color_factory(src, dst):
code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
def convert_color(img: np.ndarray) -> np.ndarray:
def convert_color(img):
out_img = cv2.cvtColor(img, code)
return out_img
......
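# A minimal sketch of how the factory above is used: it builds a named
# converter from an OpenCV ``COLOR_{SRC}2{DST}`` code. The generated name
# below is illustrative only.
import numpy as np

bgr2rgb = convert_color_factory('bgr', 'rgb')
img_bgr = np.zeros((4, 4, 3), dtype=np.uint8)
img_rgb = bgr2rgb(img_bgr)  # same shape; channel order swapped by cv2.cvtColor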
# Copyright (c) OpenMMLab. All rights reserved.
import numbers
from typing import List, Optional, Tuple, Union, no_type_check
import cv2
import numpy as np
from mmengine.utils import to_2tuple
from ..utils import to_2tuple
from .io import imread_backend
try:
......@@ -14,10 +13,7 @@ except ImportError:
Image = None
def _scale_size(
size: Tuple[int, int],
scale: Union[float, int, tuple],
) -> Tuple[int, int]:
def _scale_size(size, scale):
"""Rescale a size by a ratio.
Args:
......@@ -41,47 +37,23 @@ cv2_interp_codes = {
'lanczos': cv2.INTER_LANCZOS4
}
cv2_border_modes = {
'constant': cv2.BORDER_CONSTANT,
'replicate': cv2.BORDER_REPLICATE,
'reflect': cv2.BORDER_REFLECT,
'wrap': cv2.BORDER_WRAP,
'reflect_101': cv2.BORDER_REFLECT_101,
'transparent': cv2.BORDER_TRANSPARENT,
'isolated': cv2.BORDER_ISOLATED
}
# Pillow >=v9.1.0 uses a slightly different naming scheme for filters.
# Set pillow_interp_codes according to the naming scheme used.
if Image is not None:
if hasattr(Image, 'Resampling'):
pillow_interp_codes = {
'nearest': Image.Resampling.NEAREST,
'bilinear': Image.Resampling.BILINEAR,
'bicubic': Image.Resampling.BICUBIC,
'box': Image.Resampling.BOX,
'lanczos': Image.Resampling.LANCZOS,
'hamming': Image.Resampling.HAMMING
}
else:
pillow_interp_codes = {
'nearest': Image.NEAREST,
'bilinear': Image.BILINEAR,
'bicubic': Image.BICUBIC,
'box': Image.BOX,
'lanczos': Image.LANCZOS,
'hamming': Image.HAMMING
}
def imresize(
img: np.ndarray,
size: Tuple[int, int],
return_scale: bool = False,
interpolation: str = 'bilinear',
out: Optional[np.ndarray] = None,
backend: Optional[str] = None
) -> Union[Tuple[np.ndarray, float, float], np.ndarray]:
pillow_interp_codes = {
'nearest': Image.NEAREST,
'bilinear': Image.BILINEAR,
'bicubic': Image.BICUBIC,
'box': Image.BOX,
'lanczos': Image.LANCZOS,
'hamming': Image.HAMMING
}
def imresize(img,
size,
return_scale=False,
interpolation='bilinear',
out=None,
backend=None):
"""Resize image to a given size.
Args:
......@@ -98,7 +70,7 @@ def imresize(
Returns:
tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`.
"""
h, w = img.shape[:2]
if backend is None:
......@@ -123,18 +95,15 @@ def imresize(
return resized_img, w_scale, h_scale
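# A minimal usage sketch for the resize helper above, assuming it is exposed
# as ``mmcv.imresize``; note that ``size`` is given as (w, h).
import numpy as np

import mmcv

img = np.zeros((100, 200, 3), dtype=np.uint8)                  # (h, w, c)
resized = mmcv.imresize(img, (50, 80))                         # -> shape (80, 50, 3)
resized, w_scale, h_scale = mmcv.imresize(img, (50, 80), return_scale=True)
# w_scale == 50 / 200 == 0.25, h_scale == 80 / 100 == 0.8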
@no_type_check
def imresize_to_multiple(
img: np.ndarray,
divisor: Union[int, Tuple[int, int]],
size: Union[int, Tuple[int, int], None] = None,
scale_factor: Union[float, Tuple[float, float], None] = None,
keep_ratio: bool = False,
return_scale: bool = False,
interpolation: str = 'bilinear',
out: Optional[np.ndarray] = None,
backend: Optional[str] = None
) -> Union[Tuple[np.ndarray, float, float], np.ndarray]:
def imresize_to_multiple(img,
divisor,
size=None,
scale_factor=None,
keep_ratio=False,
return_scale=False,
interpolation='bilinear',
out=None,
backend=None):
"""Resize image according to a given size or scale factor and then rounds
up the resized or rescaled image size to the nearest value that can be
divided by the divisor.
......@@ -161,7 +130,7 @@ def imresize_to_multiple(
Returns:
tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`.
"""
h, w = img.shape[:2]
if size is not None and scale_factor is not None:
......@@ -176,7 +145,7 @@ def imresize_to_multiple(
size = _scale_size((w, h), scale_factor)
divisor = to_2tuple(divisor)
size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
resized_img, w_scale, h_scale = imresize(
img,
size,
......@@ -190,13 +159,11 @@ def imresize_to_multiple(
return resized_img
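# A minimal sketch of the divisor rounding performed above: the target size is
# rounded up, per edge, to the nearest multiple of ``divisor`` before the
# actual resize. Assumes the public name ``mmcv.imresize_to_multiple``.
import numpy as np

import mmcv

img = np.zeros((100, 200, 3), dtype=np.uint8)
out = mmcv.imresize_to_multiple(img, divisor=32, scale_factor=1.0)
# (w, h) = (200, 100) is rounded up to (224, 128), so out.shape == (128, 224, 3)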
def imresize_like(
img: np.ndarray,
dst_img: np.ndarray,
return_scale: bool = False,
interpolation: str = 'bilinear',
backend: Optional[str] = None
) -> Union[Tuple[np.ndarray, float, float], np.ndarray]:
def imresize_like(img,
dst_img,
return_scale=False,
interpolation='bilinear',
backend=None):
"""Resize image to the same size of a given image.
Args:
......@@ -208,15 +175,13 @@ def imresize_like(
Returns:
tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`.
"""
h, w = dst_img.shape[:2]
return imresize(img, (w, h), return_scale, interpolation, backend=backend)
def rescale_size(old_size: tuple,
scale: Union[float, int, tuple],
return_scale: bool = False) -> tuple:
def rescale_size(old_size, scale, return_scale=False):
"""Calculate the new size to be rescaled to.
Args:
......@@ -253,13 +218,11 @@ def rescale_size(old_size: tuple,
return new_size
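# A minimal usage sketch for ``rescale_size`` above (assuming the public name
# ``mmcv.rescale_size``): ``old_size`` is (w, h) and ``scale`` may be a ratio
# or a (max_long_edge, max_short_edge) bound.
import mmcv

new_size = mmcv.rescale_size((1000, 600), 0.5)   # (500, 300)
new_size, scale = mmcv.rescale_size((1000, 600), (512, 512), return_scale=True)
# scale is chosen so the rescaled image fits inside 512 x 512 while keeping
# its aspect ratio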
def imrescale(
img: np.ndarray,
scale: Union[float, Tuple[int, int]],
return_scale: bool = False,
interpolation: str = 'bilinear',
backend: Optional[str] = None
) -> Union[np.ndarray, Tuple[np.ndarray, float]]:
def imrescale(img,
scale,
return_scale=False,
interpolation='bilinear',
backend=None):
"""Resize image while keeping the aspect ratio.
Args:
......@@ -286,7 +249,7 @@ def imrescale(
return rescaled_img
def imflip(img: np.ndarray, direction: str = 'horizontal') -> np.ndarray:
def imflip(img, direction='horizontal'):
"""Flip an image horizontally or vertically.
Args:
......@@ -306,7 +269,7 @@ def imflip(img: np.ndarray, direction: str = 'horizontal') -> np.ndarray:
return np.flip(img, axis=(0, 1))
def imflip_(img: np.ndarray, direction: str = 'horizontal') -> np.ndarray:
def imflip_(img, direction='horizontal'):
"""Inplace flip an image horizontally or vertically.
Args:
......@@ -326,33 +289,30 @@ def imflip_(img: np.ndarray, direction: str = 'horizontal') -> np.ndarray:
return cv2.flip(img, -1, img)
def imrotate(img: np.ndarray,
angle: float,
center: Optional[Tuple[float, float]] = None,
scale: float = 1.0,
border_value: int = 0,
interpolation: str = 'bilinear',
auto_bound: bool = False,
border_mode: str = 'constant') -> np.ndarray:
def imrotate(img,
angle,
center=None,
scale=1.0,
border_value=0,
interpolation='bilinear',
auto_bound=False):
"""Rotate an image.
Args:
img (np.ndarray): Image to be rotated.
img (ndarray): Image to be rotated.
angle (float): Rotation angle in degrees, positive values mean
clockwise rotation.
center (tuple[float], optional): Center point (w, h) of the rotation in
the source image. If not specified, the center of the image will be
used.
scale (float): Isotropic scale factor.
border_value (int): Border value used in case of a constant border.
Defaults to 0.
border_value (int): Border value.
interpolation (str): Same as :func:`resize`.
auto_bound (bool): Whether to adjust the image size to cover the whole
rotated image.
border_mode (str): Pixel extrapolation method. Defaults to 'constant'.
Returns:
np.ndarray: The rotated image.
ndarray: The rotated image.
"""
if center is not None and auto_bound:
raise ValueError('`auto_bound` conflicts with `center`')
......@@ -375,12 +335,11 @@ def imrotate(img: np.ndarray,
img,
matrix, (w, h),
flags=cv2_interp_codes[interpolation],
borderMode=cv2_border_modes[border_mode],
borderValue=border_value)
return rotated
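# A minimal usage sketch for the rotation helper above, assuming the public
# name ``mmcv.imrotate``; with ``auto_bound=True`` the output canvas grows so
# the rotated image is not cropped.
import numpy as np

import mmcv

img = np.zeros((100, 200, 3), dtype=np.uint8)
rot = mmcv.imrotate(img, 30)                          # same (100, 200, 3) canvas
rot_bound = mmcv.imrotate(img, 30, auto_bound=True)   # larger canvas, no cropping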
def bbox_clip(bboxes: np.ndarray, img_shape: Tuple[int, int]) -> np.ndarray:
def bbox_clip(bboxes, img_shape):
"""Clip bboxes to fit the image shape.
Args:
......@@ -398,9 +357,7 @@ def bbox_clip(bboxes: np.ndarray, img_shape: Tuple[int, int]) -> np.ndarray:
return clipped_bboxes
def bbox_scaling(bboxes: np.ndarray,
scale: float,
clip_shape: Optional[Tuple[int, int]] = None) -> np.ndarray:
def bbox_scaling(bboxes, scale, clip_shape=None):
"""Scaling bboxes w.r.t the box center.
Args:
......@@ -426,12 +383,7 @@ def bbox_scaling(bboxes: np.ndarray,
return scaled_bboxes
def imcrop(
img: np.ndarray,
bboxes: np.ndarray,
scale: float = 1.0,
pad_fill: Union[float, list, None] = None
) -> Union[np.ndarray, List[np.ndarray]]:
def imcrop(img, bboxes, scale=1.0, pad_fill=None):
"""Crop image patches.
3 steps: scale the bboxes -> clip bboxes -> crop and pad.
......@@ -440,7 +392,7 @@ def imcrop(
img (ndarray): Image to be cropped.
bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
scale (float, optional): Scale ratio of bboxes, the default value
1.0 means no scaling.
pad_fill (Number | list[Number]): Value to be filled for padding.
Default: None, which means no padding.
......@@ -464,12 +416,10 @@ def imcrop(
patch = img[y1:y2 + 1, x1:x2 + 1, ...]
else:
_x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :])
patch_h = _y2 - _y1 + 1
patch_w = _x2 - _x1 + 1
if chn == 1:
patch_shape = (patch_h, patch_w)
patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1)
else:
patch_shape = (patch_h, patch_w, chn) # type: ignore
patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn)
patch = np.array(
pad_fill, dtype=img.dtype) * np.ones(
patch_shape, dtype=img.dtype)
......@@ -487,12 +437,12 @@ def imcrop(
return patches
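# A minimal usage sketch for ``imcrop`` above (assuming the public name
# ``mmcv.imcrop``): a single (4,) bbox yields one patch, a (k, 4) array yields
# a list of patches, and ``pad_fill`` fills regions outside the image.
import numpy as np

import mmcv

img = np.zeros((100, 200, 3), dtype=np.uint8)
patch = mmcv.imcrop(img, np.array([10, 10, 59, 39]))     # shape (30, 50, 3)
patches = mmcv.imcrop(img, np.array([[0, 0, 49, 49],
                                     [150, 50, 219, 99]]), pad_fill=0)
# the second bbox extends past the right border, so its patch is zero-padded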
def impad(img: np.ndarray,
def impad(img,
*,
shape: Optional[Tuple[int, int]] = None,
padding: Union[int, tuple, None] = None,
pad_val: Union[float, List] = 0,
padding_mode: str = 'constant') -> np.ndarray:
shape=None,
padding=None,
pad_val=0,
padding_mode='constant'):
"""Pad the given image to a certain shape or pad on all sides with
specified padding mode and padding value.
......@@ -512,16 +462,16 @@ def impad(img: np.ndarray,
reflect or symmetric. Default: constant.
- constant: pads with a constant value, this value is specified
with pad_val.
- edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the last
value on the edge. For example, padding [1, 2, 3, 4] with 2
elements on both sides in reflect mode will result in
[3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: pads with reflection of image repeating the last value
on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3]
Returns:
ndarray: The padded image.
......@@ -529,9 +479,7 @@ def impad(img: np.ndarray,
assert (shape is not None) ^ (padding is not None)
if shape is not None:
width = max(shape[1] - img.shape[1], 0)
height = max(shape[0] - img.shape[0], 0)
padding = (0, 0, width, height)
padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0])
# check pad_val
if isinstance(pad_val, tuple):
......@@ -571,9 +519,7 @@ def impad(img: np.ndarray,
return img
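# A minimal usage sketch for ``impad`` above (assuming the public name
# ``mmcv.impad``): either a target ``shape`` (h, w) or an explicit ``padding``
# (left, top, right, bottom) is given, never both.
import numpy as np

import mmcv

img = np.ones((10, 20, 3), dtype=np.uint8)
padded = mmcv.impad(img, shape=(16, 32), pad_val=0)        # -> (16, 32, 3)
padded = mmcv.impad(img, padding=(1, 2, 3, 4), pad_val=0)  # -> (16, 24, 3)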
def impad_to_multiple(img: np.ndarray,
divisor: int,
pad_val: Union[float, List] = 0) -> np.ndarray:
def impad_to_multiple(img, divisor, pad_val=0):
"""Pad an image to ensure each edge to be multiple to some number.
Args:
......@@ -589,9 +535,7 @@ def impad_to_multiple(img: np.ndarray,
return impad(img, shape=(pad_h, pad_w), pad_val=pad_val)
def cutout(img: np.ndarray,
shape: Union[int, Tuple[int, int]],
pad_val: Union[int, float, tuple] = 0) -> np.ndarray:
def cutout(img, shape, pad_val=0):
"""Randomly cut out a rectangle from the original img.
Args:
......@@ -635,7 +579,7 @@ def cutout(img: np.ndarray,
if img.ndim == 2:
patch_shape = (y2 - y1, x2 - x1)
else:
patch_shape = (y2 - y1, x2 - x1, channels) # type: ignore
patch_shape = (y2 - y1, x2 - x1, channels)
img_cutout = img.copy()
patch = np.array(
......@@ -646,8 +590,7 @@ def cutout(img: np.ndarray,
return img_cutout
def _get_shear_matrix(magnitude: Union[int, float],
direction: str = 'horizontal') -> np.ndarray:
def _get_shear_matrix(magnitude, direction='horizontal'):
"""Generate the shear matrix for transformation.
Args:
......@@ -665,11 +608,11 @@ def _get_shear_matrix(magnitude: Union[int, float],
return shear_matrix
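# A minimal sketch of the matrices this helper is expected to return (its body
# is elided above; the forms below are the usual affine shear matrices and are
# stated here as an assumption, not read from the diff).
import numpy as np

m = _get_shear_matrix(0.5)  # horizontal shear
# expected: np.allclose(m, np.float32([[1, 0.5, 0], [0, 1, 0]]))
# a vertical shear would instead use [[1, 0, 0], [magnitude, 1, 0]]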
def imshear(img: np.ndarray,
magnitude: Union[int, float],
direction: str = 'horizontal',
border_value: Union[int, Tuple[int, int]] = 0,
interpolation: str = 'bilinear') -> np.ndarray:
def imshear(img,
magnitude,
direction='horizontal',
border_value=0,
interpolation='bilinear'):
"""Shear an image.
Args:
......@@ -693,7 +636,7 @@ def imshear(img: np.ndarray,
elif img.ndim == 3:
channels = img.shape[-1]
if isinstance(border_value, int):
border_value = tuple([border_value] * channels) # type: ignore
border_value = tuple([border_value] * channels)
elif isinstance(border_value, tuple):
assert len(border_value) == channels, \
'Expected the num of elements in tuple equals the channels' \
......@@ -711,13 +654,12 @@ def imshear(img: np.ndarray,
# greater than 3 (e.g. shearing masks whose number of channels is
# larger than 3) will raise TypeError in `cv2.warpAffine`.
# Here simply slice the first 3 values in `border_value`.
borderValue=border_value[:3], # type: ignore
borderValue=border_value[:3],
flags=cv2_interp_codes[interpolation])
return sheared
def _get_translate_matrix(offset: Union[int, float],
direction: str = 'horizontal') -> np.ndarray:
def _get_translate_matrix(offset, direction='horizontal'):
"""Generate the translate matrix.
Args:
......@@ -735,11 +677,11 @@ def _get_translate_matrix(offset: Union[int, float],
return translate_matrix
def imtranslate(img: np.ndarray,
offset: Union[int, float],
direction: str = 'horizontal',
border_value: Union[int, tuple] = 0,
interpolation: str = 'bilinear') -> np.ndarray:
def imtranslate(img,
offset,
direction='horizontal',
border_value=0,
interpolation='bilinear'):
"""Translate an image.
Args:
......
# Copyright (c) OpenMMLab. All rights reserved.
import io
import os.path as osp
import warnings
from pathlib import Path
from typing import Optional, Union
import cv2
import mmengine.fileio as fileio
import numpy as np
from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
IMREAD_UNCHANGED)
from mmengine.utils import is_filepath, is_str
from mmcv.utils import check_file_exist, is_str, mkdir_or_exist
try:
from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
......@@ -42,7 +40,7 @@ imread_flags = {
imread_backend = 'cv2'
def use_backend(backend: str) -> None:
def use_backend(backend):
"""Select a backend for image decoding.
Args:
......@@ -68,7 +66,7 @@ def use_backend(backend: str) -> None:
raise ImportError('`tifffile` is not installed')
def _jpegflag(flag: str = 'color', channel_order: str = 'bgr'):
def _jpegflag(flag='color', channel_order='bgr'):
channel_order = channel_order.lower()
if channel_order not in ['rgb', 'bgr']:
raise ValueError('channel order must be either "rgb" or "bgr"')
......@@ -84,9 +82,7 @@ def _jpegflag(flag: str = 'color', channel_order: str = 'bgr'):
raise ValueError('flag must be "color" or "grayscale"')
def _pillow2array(img,
flag: str = 'color',
channel_order: str = 'bgr') -> np.ndarray:
def _pillow2array(img, flag='color', channel_order='bgr'):
"""Convert a pillow image to numpy array.
Args:
......@@ -141,13 +137,7 @@ def _pillow2array(img,
return array
def imread(img_or_path: Union[np.ndarray, str, Path],
flag: str = 'color',
channel_order: str = 'bgr',
backend: Optional[str] = None,
file_client_args: Optional[dict] = None,
*,
backend_args: Optional[dict] = None) -> np.ndarray:
def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
"""Read an image.
Args:
......@@ -167,117 +157,78 @@ def imread(img_or_path: Union[np.ndarray, str, Path],
`cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
If backend is None, the global imread_backend specified by
``mmcv.use_backend()`` will be used. Default: None.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details.
Default: None. It will be deprecated in future. Please use
``backend_args`` instead.
Deprecated in version 2.0.0rc4.
backend_args (dict, optional): Instantiates the corresponding file
backend. It may contain `backend` key to specify the file
backend. If it contains, the file backend corresponding to this
value will be used and initialized with the remaining values,
otherwise the corresponding file backend will be selected
based on the prefix of the file path. Defaults to None.
New in version 2.0.0rc4.
Returns:
ndarray: Loaded image array.
Examples:
>>> import mmcv
>>> img_path = '/path/to/img.jpg'
>>> img = mmcv.imread(img_path)
>>> img = mmcv.imread(img_path, flag='color', channel_order='rgb',
... backend='cv2')
>>> img = mmcv.imread(img_path, flag='color', channel_order='bgr',
... backend='pillow')
>>> s3_img_path = 's3://bucket/img.jpg'
>>> # infer the file backend by the prefix s3
>>> img = mmcv.imread(s3_img_path)
>>> # manually set the file backend petrel
>>> img = mmcv.imread(s3_img_path, backend_args={
... 'backend': 'petrel'})
>>> http_img_path = 'http://path/to/img.jpg'
>>> img = mmcv.imread(http_img_path)
>>> img = mmcv.imread(http_img_path, backend_args={
... 'backend': 'http'})
"""
if file_client_args is not None:
warnings.warn(
'"file_client_args" will be deprecated in future. '
'Please use "backend_args" instead', DeprecationWarning)
if backend_args is not None:
raise ValueError(
'"file_client_args" and "backend_args" cannot be set at the '
'same time.')
if backend is None:
backend = imread_backend
if backend not in supported_backends:
raise ValueError(f'backend: {backend} is not supported. Supported '
"backends are 'cv2', 'turbojpeg', 'pillow'")
if isinstance(img_or_path, Path):
img_or_path = str(img_or_path)
if isinstance(img_or_path, np.ndarray):
return img_or_path
elif is_str(img_or_path):
if file_client_args is not None:
file_client = fileio.FileClient.infer_client(
file_client_args, img_or_path)
img_bytes = file_client.get(img_or_path)
check_file_exist(img_or_path,
f'img file does not exist: {img_or_path}')
if backend == 'turbojpeg':
with open(img_or_path, 'rb') as in_file:
img = jpeg.decode(in_file.read(),
_jpegflag(flag, channel_order))
if img.shape[-1] == 1:
img = img[:, :, 0]
return img
elif backend == 'pillow':
img = Image.open(img_or_path)
img = _pillow2array(img, flag, channel_order)
return img
elif backend == 'tifffile':
img = tifffile.imread(img_or_path)
return img
else:
img_bytes = fileio.get(img_or_path, backend_args=backend_args)
return imfrombytes(img_bytes, flag, channel_order, backend)
flag = imread_flags[flag] if is_str(flag) else flag
img = cv2.imread(img_or_path, flag)
if flag == IMREAD_COLOR and channel_order == 'rgb':
cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
return img
else:
raise TypeError('"img" must be a numpy array or a str or '
'a pathlib.Path object')
def imfrombytes(content: bytes,
flag: str = 'color',
channel_order: str = 'bgr',
backend: Optional[str] = None) -> np.ndarray:
def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
"""Read an image from bytes.
Args:
content (bytes): Image bytes got from files or other streams.
flag (str): Same as :func:`imread`.
channel_order (str): The channel order of the output, candidates
are 'bgr' and 'rgb'. Default to 'bgr'.
backend (str | None): The image decoding backend type. Options are
`cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. If backend is
None, the global imread_backend specified by ``mmcv.use_backend()``
will be used. Default: None.
`cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
global imread_backend specified by ``mmcv.use_backend()`` will be
used. Default: None.
Returns:
ndarray: Loaded image array.
Examples:
>>> img_path = '/path/to/img.jpg'
>>> with open(img_path, 'rb') as f:
>>> img_buff = f.read()
>>> img = mmcv.imfrombytes(img_buff)
>>> img = mmcv.imfrombytes(img_buff, flag='color', channel_order='rgb')
>>> img = mmcv.imfrombytes(img_buff, backend='pillow')
>>> img = mmcv.imfrombytes(img_buff, backend='cv2')
"""
if backend is None:
backend = imread_backend
if backend not in supported_backends:
raise ValueError(
f'backend: {backend} is not supported. Supported '
"backends are 'cv2', 'turbojpeg', 'pillow', 'tifffile'")
raise ValueError(f'backend: {backend} is not supported. Supported '
"backends are 'cv2', 'turbojpeg', 'pillow'")
if backend == 'turbojpeg':
img = jpeg.decode( # type: ignore
content, _jpegflag(flag, channel_order))
img = jpeg.decode(content, _jpegflag(flag, channel_order))
if img.shape[-1] == 1:
img = img[:, :, 0]
return img
elif backend == 'pillow':
with io.BytesIO(content) as buff:
img = Image.open(buff)
img = _pillow2array(img, flag, channel_order)
return img
elif backend == 'tifffile':
with io.BytesIO(content) as buff:
img = tifffile.imread(buff)
buff = io.BytesIO(content)
img = Image.open(buff)
img = _pillow2array(img, flag, channel_order)
return img
else:
img_np = np.frombuffer(content, np.uint8)
......@@ -288,77 +239,20 @@ def imfrombytes(content: bytes,
return img
def imwrite(img: np.ndarray,
file_path: str,
params: Optional[list] = None,
auto_mkdir: Optional[bool] = None,
file_client_args: Optional[dict] = None,
*,
backend_args: Optional[dict] = None) -> bool:
def imwrite(img, file_path, params=None, auto_mkdir=True):
"""Write image to file.
Warning:
The parameter `auto_mkdir` will be deprecated in the future and every
file client will create the directory automatically.
Args:
img (ndarray): Image array to be written.
file_path (str): Image file path.
params (None or list): Same as opencv :func:`imwrite` interface.
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
whether to create it automatically. It will be deprecated.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details.
Default: None. It will be deprecated in future. Please use
``backend_args`` instead.
Deprecated in version 2.0.0rc4.
backend_args (dict, optional): Instantiates the corresponding file
backend. It may contain `backend` key to specify the file
backend. If it contains, the file backend corresponding to this
value will be used and initialized with the remaining values,
otherwise the corresponding file backend will be selected
based on the prefix of the file path. Defaults to None.
New in version 2.0.0rc4.
whether to create it automatically.
Returns:
bool: Successful or not.
Examples:
>>> # write to hard disk client
>>> ret = mmcv.imwrite(img, '/path/to/img.jpg')
>>> # infer the file backend by the prefix s3
>>> ret = mmcv.imwrite(img, 's3://bucket/img.jpg')
>>> # manually set the file backend petrel
>>> ret = mmcv.imwrite(img, 's3://bucket/img.jpg', backend_args={
... 'backend': 'petrel'})
"""
if file_client_args is not None:
warnings.warn(
'"file_client_args" will be deprecated in future. '
'Please use "backend_args" instead', DeprecationWarning)
if backend_args is not None:
raise ValueError(
'"file_client_args" and "backend_args" cannot be set at the '
'same time.')
assert is_filepath(file_path)
file_path = str(file_path)
if auto_mkdir is not None:
warnings.warn(
'The parameter `auto_mkdir` will be deprecated in the future and '
'every file client will create the directory automatically.')
img_ext = osp.splitext(file_path)[-1]
# Encode image according to image suffix.
# For example, if image path is '/path/your/img.jpg', the encode
# format is '.jpg'.
flag, img_buff = cv2.imencode(img_ext, img, params)
if file_client_args is not None:
file_client = fileio.FileClient.infer_client(file_client_args,
file_path)
file_client.put(img_buff.tobytes(), file_path)
else:
fileio.put(img_buff.tobytes(), file_path, backend_args=backend_args)
return flag
if auto_mkdir:
dir_name = osp.abspath(osp.dirname(file_path))
mkdir_or_exist(dir_name)
return cv2.imwrite(file_path, img, params)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import numpy as np
import mmcv
......@@ -11,24 +9,18 @@ except ImportError:
torch = None
def tensor2imgs(tensor,
mean: Optional[tuple] = None,
std: Optional[tuple] = None,
to_rgb: bool = True) -> list:
"""Convert tensor to 3-channel images or 1-channel gray images.
def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
"""Convert tensor to 3-channel images.
Args:
tensor (torch.Tensor): Tensor that contains multiple images, shape (
N, C, H, W). :math:`C` can be either 3 or 1.
mean (tuple[float], optional): Mean of images. If None,
(0, 0, 0) will be used for tensor with 3-channel,
while (0, ) for tensor with 1-channel. Defaults to None.
std (tuple[float], optional): Standard deviation of images. If None,
(1, 1, 1) will be used for tensor with 3-channel,
while (1, ) for tensor with 1-channel. Defaults to None.
N, C, H, W).
mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0).
std (tuple[float], optional): Standard deviation of images.
Defaults to (1, 1, 1).
to_rgb (bool, optional): Whether the tensor was converted to RGB
format in the first place. If so, convert it back to BGR.
For the tensor with 1 channel, it must be False. Defaults to True.
Defaults to True.
Returns:
list[np.ndarray]: A list that contains multiple images.
......@@ -37,14 +29,8 @@ def tensor2imgs(tensor,
if torch is None:
raise RuntimeError('pytorch is not installed')
assert torch.is_tensor(tensor) and tensor.ndim == 4
channels = tensor.size(1)
assert channels in [1, 3]
if mean is None:
mean = (0, ) * channels
if std is None:
std = (1, ) * channels
assert (channels == len(mean) == len(std) == 3) or \
(channels == len(mean) == len(std) == 1 and not to_rgb)
assert len(mean) == 3
assert len(std) == 3
num_imgs = tensor.size(0)
mean = np.array(mean, dtype=np.float32)
......
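# A minimal usage sketch for ``tensor2imgs`` above (assuming the public name
# ``mmcv.tensor2imgs``): it de-normalizes an (N, C, H, W) tensor and returns a
# list of (H, W, C) numpy images. Requires torch; values are illustrative.
import torch

import mmcv

tensor = torch.rand(2, 3, 32, 32)
imgs = mmcv.tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True)
# len(imgs) == 2 and imgs[0].shape == (32, 32, 3)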
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Optional
import cv2
import numpy as np
from mmengine.utils import is_tuple_of
from PIL import Image, ImageEnhance
from ..utils import is_tuple_of
from .colorspace import bgr2gray, gray2bgr
from .io import imread_backend
def imnormalize(img, mean, std, to_rgb=True):
......@@ -102,7 +97,7 @@ def posterize(img, bits):
return img
def adjust_color(img, alpha=1, beta=None, gamma=0, backend=None):
def adjust_color(img, alpha=1, beta=None, gamma=0):
r"""It blends the source image and its gray image:
.. math::
......@@ -115,41 +110,22 @@ def adjust_color(img, alpha=1, beta=None, gamma=0, backend=None):
If None, it's assigned the value (1 - `alpha`).
gamma (int | float): Scalar added to each sum.
Same as :func:`cv2.addWeighted`. Default 0.
backend (str | None): The image processing backend type. Options are
`cv2`, `pillow`, `None`. If backend is None, the global
``imread_backend`` specified by ``mmcv.use_backend()`` will be
used. Defaults to None.
Returns:
ndarray: Colored image which has the same size and dtype as input.
"""
if backend is None:
backend = imread_backend
if backend not in ['cv2', 'pillow']:
raise ValueError(f'backend: {backend} is not supported.'
f"Supported backends are 'cv2', 'pillow'")
if backend == 'pillow':
assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
warnings.warn("Only use 'alpha' for pillow backend.")
# Image.fromarray defaultly supports RGB, not BGR.
pil_image = Image.fromarray(img[..., ::-1], mode='RGB')
enhancer = ImageEnhance.Color(pil_image)
pil_image = enhancer.enhance(alpha)
return np.array(pil_image, dtype=img.dtype)[..., ::-1]
else:
gray_img = bgr2gray(img)
gray_img = np.tile(gray_img[..., None], [1, 1, 3])
if beta is None:
beta = 1 - alpha
colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
if not colored_img.dtype == np.uint8:
# Note when the dtype of `img` is not the default `np.uint8`
# (e.g. np.float32), the value in `colored_img` got from cv2
# is not guaranteed to be in range [0, 255], so here clip
# is needed.
colored_img = np.clip(colored_img, 0, 255)
return colored_img.astype(img.dtype)
gray_img = bgr2gray(img)
gray_img = np.tile(gray_img[..., None], [1, 1, 3])
if beta is None:
beta = 1 - alpha
colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
if not colored_img.dtype == np.uint8:
# Note when the dtype of `img` is not the default `np.uint8`
# (e.g. np.float32), the value in `colored_img` got from cv2
# is not guaranteed to be in range [0, 255], so here clip
# is needed.
colored_img = np.clip(colored_img, 0, 255)
return colored_img
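# A minimal usage sketch for ``adjust_color`` above (assuming the public name
# ``mmcv.adjust_color``): ``alpha`` blends the BGR image with its grayscale
# version, so alpha=1 keeps the original content. The image is illustrative.
import numpy as np

import mmcv

img = np.random.randint(0, 256, size=(16, 16, 3), dtype=np.uint8)
same = mmcv.adjust_color(img, alpha=1.0)    # content unchanged
faded = mmcv.adjust_color(img, alpha=0.5)   # halfway towards its gray image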
def imequalize(img):
......@@ -197,7 +173,7 @@ def imequalize(img):
return equalized_img.astype(img.dtype)
def adjust_brightness(img, factor=1., backend=None):
def adjust_brightness(img, factor=1.):
"""Adjust image brightness.
This function controls the brightness of an image. An
......@@ -214,40 +190,22 @@ def adjust_brightness(img, factor=1., backend=None):
Factor 1.0 returns the original image, lower
factors mean less color (brightness, contrast,
etc), and higher values more. Default 1.
backend (str | None): The image processing backend type. Options are
`cv2`, `pillow`, `None`. If backend is None, the global
``imread_backend`` specified by ``mmcv.use_backend()`` will be
used. Defaults to None.
Returns:
ndarray: The brightened image.
"""
if backend is None:
backend = imread_backend
if backend not in ['cv2', 'pillow']:
raise ValueError(f'backend: {backend} is not supported.'
f"Supported backends are 'cv2', 'pillow'")
if backend == 'pillow':
assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
# Image.fromarray defaultly supports RGB, not BGR.
pil_image = Image.fromarray(img[..., ::-1], mode='RGB')
enhancer = ImageEnhance.Brightness(pil_image)
pil_image = enhancer.enhance(factor)
return np.array(pil_image, dtype=img.dtype)[..., ::-1]
else:
degenerated = np.zeros_like(img)
# Note manually convert the dtype to np.float32, to
# achieve as close results as PIL.ImageEnhance.Brightness.
# Set beta=1-factor, and gamma=0
brightened_img = cv2.addWeighted(
img.astype(np.float32), factor, degenerated.astype(np.float32),
1 - factor, 0)
brightened_img = np.clip(brightened_img, 0, 255)
return brightened_img.astype(img.dtype)
def adjust_contrast(img, factor=1., backend=None):
degenerated = np.zeros_like(img)
# Note manually convert the dtype to np.float32, to
# achieve as close results as PIL.ImageEnhance.Brightness.
# Set beta=1-factor, and gamma=0
brightened_img = cv2.addWeighted(
img.astype(np.float32), factor, degenerated.astype(np.float32),
1 - factor, 0)
brightened_img = np.clip(brightened_img, 0, 255)
return brightened_img.astype(img.dtype)
def adjust_contrast(img, factor=1.):
"""Adjust image contrast.
This function controls the contrast of an image. An
......@@ -261,38 +219,20 @@ def adjust_contrast(img, factor=1., backend=None):
Args:
img (ndarray): Image to be contrasted. BGR order.
factor (float): Same as :func:`mmcv.adjust_brightness`.
backend (str | None): The image processing backend type. Options are
`cv2`, `pillow`, `None`. If backend is None, the global
``imread_backend`` specified by ``mmcv.use_backend()`` will be
used. Defaults to None.
Returns:
ndarray: The contrasted image.
"""
if backend is None:
backend = imread_backend
if backend not in ['cv2', 'pillow']:
raise ValueError(f'backend: {backend} is not supported.'
f"Supported backends are 'cv2', 'pillow'")
if backend == 'pillow':
assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
# Image.fromarray defaultly supports RGB, not BGR.
pil_image = Image.fromarray(img[..., ::-1], mode='RGB')
enhancer = ImageEnhance.Contrast(pil_image)
pil_image = enhancer.enhance(factor)
return np.array(pil_image, dtype=img.dtype)[..., ::-1]
else:
gray_img = bgr2gray(img)
hist = np.histogram(gray_img, 256, (0, 255))[0]
mean = round(np.sum(gray_img) / np.sum(hist))
degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
degenerated = gray2bgr(degenerated)
contrasted_img = cv2.addWeighted(
img.astype(np.float32), factor, degenerated.astype(np.float32),
1 - factor, 0)
contrasted_img = np.clip(contrasted_img, 0, 255)
return contrasted_img.astype(img.dtype)
gray_img = bgr2gray(img)
hist = np.histogram(gray_img, 256, (0, 255))[0]
mean = round(np.sum(gray_img) / np.sum(hist))
degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
degenerated = gray2bgr(degenerated)
contrasted_img = cv2.addWeighted(
img.astype(np.float32), factor, degenerated.astype(np.float32),
1 - factor, 0)
contrasted_img = np.clip(contrasted_img, 0, 255)
return contrasted_img.astype(img.dtype)
def auto_contrast(img, cutoff=0):
......@@ -486,76 +426,3 @@ def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
return clahe.apply(np.array(img, dtype=np.uint8))
def adjust_hue(img: np.ndarray,
hue_factor: float,
backend: Optional[str] = None) -> np.ndarray:
"""Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and cyclically
shifting the intensities in the hue channel (H). The image is then
converted back to original image mode.
`hue_factor` is the amount of shift in H channel and must be in the
interval `[-0.5, 0.5]`.
Modified from
https://github.com/pytorch/vision/blob/main/torchvision/
transforms/functional.py
Args:
img (ndarray): Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.
backend (str | None): The image processing backend type. Options are
`cv2`, `pillow`, `None`. If backend is None, the global
``imread_backend`` specified by ``mmcv.use_backend()`` will be
used. Defaults to None.
Returns:
ndarray: Hue adjusted image.
"""
if backend is None:
backend = imread_backend
if backend not in ['cv2', 'pillow']:
raise ValueError(f'backend: {backend} is not supported.'
f"Supported backends are 'cv2', 'pillow'")
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f'hue_factor:{hue_factor} is not in [-0.5, 0.5].')
if not (isinstance(img, np.ndarray) and (img.ndim in {2, 3})):
raise TypeError('img should be ndarray with dim=[2 or 3].')
if backend == 'pillow':
assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
# Image.fromarray defaultly supports RGB, not BGR.
pil_image = Image.fromarray(img[..., ::-1], mode='RGB')
input_mode = pil_image.mode
if input_mode in {'L', '1', 'I', 'F'}:
return pil_image
h, s, v = pil_image.convert('HSV').split()
np_h = np.array(h, dtype=np.uint8)
# uint8 addition takes care of rotation across boundaries
with np.errstate(over='ignore'):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, 'L')
pil_image = Image.merge('HSV', (h, s, v)).convert(input_mode)
return np.array(pil_image, dtype=img.dtype)[..., ::-1]
else:
dtype = img.dtype
img = img.astype(np.uint8)
hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV_FULL)
h, s, v = cv2.split(hsv_img)
h = h.astype(np.uint8)
# uint8 addition takes care of rotation across boundaries
with np.errstate(over='ignore'):
h += np.uint8(hue_factor * 255)
hsv_img = cv2.merge([h, s, v])
return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype)
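# A minimal usage sketch for ``adjust_hue`` above (assuming the public name
# ``mmcv.adjust_hue``): ``hue_factor`` cyclically shifts the H channel, so 0
# is essentially a no-op and +/-0.5 gives complementary colors.
import numpy as np

import mmcv

img = np.random.randint(0, 256, size=(16, 16, 3), dtype=np.uint8)
unchanged = mmcv.adjust_hue(img, 0.0)
complementary = mmcv.adjust_hue(img, 0.5)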
{
"resnet50_caffe": "detectron/resnet50_caffe",
"resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
"resnet101_caffe": "detectron/resnet101_caffe",
"resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
}
{
"vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
"vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
"vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
"vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
"vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
"vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
"vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
"vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
"resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth",
"resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth",
"resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth",
"resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth",
"resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth",
"resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth",
"resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth",
"resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth",
"resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
"resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
"resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
"resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
"se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
"se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
"resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
"resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
"resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
"resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
"shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
"shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
"mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth",
"mobilenet_v3_small": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth",
"mobilenet_v3_large": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth",
"repvgg_A0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth",
"repvgg_A1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth",
"repvgg_A2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth",
"repvgg_B0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth",
"repvgg_B1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth",
"repvgg_B1g2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth",
"repvgg_B1g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth",
"repvgg_B2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth",
"repvgg_B2g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth",
"repvgg_B3": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth",
"repvgg_B3g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth",
"repvgg_D2se": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth",
"res2net101_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth",
"res2net50_w14": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth",
"res2net50_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth",
"swin_tiny": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth",
"swin_small": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth",
"swin_base": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth",
"swin_large": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth",
"t2t_vit_t_14": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth",
"t2t_vit_t_19": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth",
"t2t_vit_t_24": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth",
"tnt_small": "https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth",
"vit_base_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth",
"vit_base_p32": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth",
"vit_large_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-large-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-b20ba619.pth"
}
{
"vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
"detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
"detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
"detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
"detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
"detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
"resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
"resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
"resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
"contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
"detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
"detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
"jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
"jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
"jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
"jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
"jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
"jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
"msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
"msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
"msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
"msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
"msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
"bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
"kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
"kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
"res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
"regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
"regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
"regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
"regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
"regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
"regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
"regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
"regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
"resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
"resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
"resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
"mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
"mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
"mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
"contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
"contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
"resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
"resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
"resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
"darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
"mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth"
}
# Copyright (c) OpenMMLab. All rights reserved.
from .info import is_custom_op_loaded
from .symbolic import register_extra_symbolics
__all__ = ['register_extra_symbolics', 'is_custom_op_loaded']
# Copyright (c) OpenMMLab. All rights reserved.
import os
import torch
def is_custom_op_loaded():
flag = False
try:
from ..tensorrt import is_tensorrt_plugin_loaded
flag = is_tensorrt_plugin_loaded()
except (ImportError, ModuleNotFoundError):
pass
if not flag:
try:
from ..ops import get_onnxruntime_op_path
ort_lib_path = get_onnxruntime_op_path()
flag = os.path.exists(ort_lib_path)
except (ImportError, ModuleNotFoundError):
pass
return flag or torch.__version__ == 'parrots'
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import TRANSFORMS # noqa: F401
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/pytorch/pytorch."""
import warnings
from functools import wraps
from sys import maxsize
import torch
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from torch._C import ListType
# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------
# Save some builtins as locals, because we'll shadow them below
_sum = sum
def _parse_arg(value, desc):
if desc == 'none':
return value
if desc == 'v' or not _is_value(value):
return value
if value.node().mustBeNone():
return None
if value.node().kind() == 'onnx::Constant':
tval = value.node()['value']
if desc == 'i':
return int(tval)
elif desc == 'f':
return float(tval)
elif desc == 'b':
return bool(tval)
elif desc == 's':
return str(tval)
elif desc == 't':
return tval
elif desc == 'is':
return [int(v) for v in tval]
elif desc == 'fs':
return [float(v) for v in tval]
else:
raise RuntimeError(
"ONNX symbolic doesn't know to interpret Constant node")
elif value.node().kind() == 'prim::ListConstruct':
if desc == 'is':
for v in value.node().inputs():
if v.node().kind() != 'onnx::Constant':
raise RuntimeError(
"Failed to export an ONNX attribute '" +
v.node().kind() +
"', since it's not constant, please try to make "
'things (e.g., kernel size) static if possible')
return [int(v.node()['value']) for v in value.node().inputs()]
else:
raise RuntimeError(
"ONNX symbolic doesn't know to interpret ListConstruct node")
raise RuntimeError('Unexpected node type: {}'.format(value.node().kind()))
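# Descriptor codes accepted by ``_parse_arg`` above (and used by ``parse_args``
# below): 'v' keeps the raw graph value, 'i'/'f'/'b'/'s' coerce an
# ``onnx::Constant`` to int/float/bool/str, 't' returns its tensor value,
# 'is'/'fs' unpack a constant list into ints/floats, and 'none' passes the
# value through untouched.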
def _maybe_get_const(value, desc):
if _is_value(value) and value.node().kind() == 'onnx::Constant':
return _parse_arg(value, desc)
return value
def _maybe_get_scalar(value):
value_t = _maybe_get_const(value, 't')
if isinstance(value_t, torch.Tensor) and value_t.shape == ():
return value_t
return value
def _get_const(value, desc, arg_name):
if _is_value(value) and value.node().kind() not in ('onnx::Constant',
'prim::Constant'):
raise RuntimeError('ONNX symbolic expected a constant'
' value of the {} argument, got `{}`'.format(
arg_name, value))
return _parse_arg(value, desc)
def _unpack_list(list_value):
list_node = list_value.node()
assert list_node.kind() == 'prim::ListConstruct'
return list(list_node.inputs())
# Check if list_value is output from prim::ListConstruct
# This is usually called before _unpack_list to ensure the list can be
# unpacked.
def _is_packed_list(list_value):
return _is_value(
list_value) and list_value.node().kind() == 'prim::ListConstruct'
def parse_args(*arg_descriptors):
def decorator(fn):
fn._arg_descriptors = arg_descriptors
def wrapper(g, *args):
# some args may be optional, so the length may be smaller
assert len(arg_descriptors) >= len(args)
args = [
_parse_arg(arg, arg_desc)
for arg, arg_desc in zip(args, arg_descriptors)
]
return fn(g, *args)
# In Python 2 functools.wraps chokes on partially applied functions, so
# we need this as a workaround
try:
wrapper = wraps(fn)(wrapper)
except Exception:
pass
return wrapper
return decorator
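# A minimal sketch of how ``parse_args`` is used when writing a symbolic
# function: the per-argument descriptors tell the wrapper how to coerce the
# incoming graph values. The operator below is hypothetical, for illustration
# only.
@parse_args('v', 'i', 'f')
def example_symbolic(g, input, dim, eps):
    # ``input`` stays a graph value, ``dim`` arrives as an int, ``eps`` as a
    # float; ``eps`` is accepted but unused in this illustration.
    return g.op('Softmax', input, axis_i=dim)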
def _scalar(x):
"""Convert a scalar tensor into a Python value."""
assert x.numel() == 1
return x.item()
def _if_scalar_type_as(g, self, tensor):
"""Convert self into the same type of tensor, as necessary."""
if isinstance(self, torch._C.Value):
return self
scalar_type = tensor.type().scalarType()
if scalar_type:
ty = scalar_type.lower()
return getattr(self, ty)()
return self
def _is_none(x):
return x.node().mustBeNone()
def _is_value(x):
return isinstance(x, torch._C.Value)
def _is_tensor_list(x):
return x.type().isSubtypeOf(ListType.ofTensors())
def _unimplemented(op, msg):
warnings.warn('ONNX export failed on ' + op + ' because ' + msg +
' not supported')
def _try_get_scalar_type(*args):
for arg in args:
try:
return arg.type().scalarType()
except RuntimeError:
pass
return None
def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
if out is not None:
        _unimplemented('TopK', 'Out parameter')
if not _is_value(k):
k = g.op('Constant', value_t=torch.tensor([k], dtype=torch.int64))
else:
k = g.op('Reshape', k, g.op('Constant', value_t=torch.tensor([1])))
return g.op(
'TopK',
input,
k,
axis_i=dim,
largest_i=largest,
sorted_i=sorted,
outputs=2)
def _slice_helper(g,
input,
axes,
starts,
ends,
steps=None,
dynamic_slice=False):
# TODO(ruobing): add support for opset<10
from torch.onnx.symbolic_opset10 import _slice
return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
def _unsqueeze_helper(g, input, dim):
from torch.onnx.symbolic_opset9 import unsqueeze
return unsqueeze(g, input, dim)
def _interpolate_size_to_scales(g, input, output_size, dim):
output_size = _maybe_get_const(output_size, 'is')
if _is_value(output_size):
offset = 2
offsets = g.op(
'Constant', value_t=torch.ones(offset, dtype=torch.float32))
dividend = g.op(
'Cast', output_size, to_i=cast_pytorch_to_onnx['Float'])
divisor = _slice_helper(
g, g.op('Shape', input), axes=[0], ends=[maxsize], starts=[offset])
divisor = g.op('Cast', divisor, to_i=cast_pytorch_to_onnx['Float'])
scale_dims = g.op('Div', dividend, divisor)
scales = g.op('Concat', offsets, scale_dims, axis_i=0)
else:
scales_constant = [
1. if i < 2 else float(output_size[-(dim - i)]) /
float(input.type().sizes()[-(dim - i)]) for i in range(0, dim)
]
scales = g.op(
'Constant',
value_t=torch.tensor(scales_constant, dtype=torch.float32))
return scales
def _interpolate_get_scales_if_available(g, scales):
if len(scales) == 0:
return None
# scales[0] is NoneType in Pytorch == 1.5.1
# scales[0] is TensorType with sizes = [] in Pytorch == 1.6.0
# scales[0] is ListType in Pytorch == 1.7.0
# scales[0] is TensorType with sizes = [2] in Pytorch == 1.8.0
scale_desc = 'fs' if scales[0].type().kind() == 'ListType' or (
scales[0].type().kind() == 'TensorType' and
(sum(scales[0].type().sizes()) > 1)) else 'f'
available_scales = _maybe_get_const(
scales[0], scale_desc) != -1 and not _is_none(scales[0])
if not available_scales:
return None
offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
if scale_desc == 'fs':
scales_list = g.op(
'Constant',
value_t=torch.tensor(_maybe_get_const(scales[0], scale_desc)))
# modify to support PyTorch==1.7.0
# https://github.com/pytorch/pytorch/blob/75ee5756715e7161314ce037474843b68f69fc04/torch/onnx/symbolic_helper.py#L375 # noqa: E501
scales = g.op('Concat', offsets, scales_list, axis_i=0)
else:
# for PyTorch < 1.7.0
scales_list = []
for scale in scales:
unsqueezed_scale = _unsqueeze_helper(g, scale, 0)
# ONNX only supports float for the scales. double -> float.
unsqueezed_scale = g.op(
'Cast', unsqueezed_scale, to_i=cast_pytorch_to_onnx['Float'])
scales_list.append(unsqueezed_scale)
scales = g.op('Concat', offsets, *scales_list, axis_i=0)
return scales
def _get_interpolate_attributes(g, mode, args):
if mode == 'nearest':
align_corners = None
scales = args[0:]
else:
align_corners = args[0]
scales = args[1:]
scales = _interpolate_get_scales_if_available(g, scales)
return scales, align_corners
def _interpolate_get_scales(g, scale_factor, dim):
offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
if isinstance(scale_factor.type(), torch._C.ListType):
return g.op('Concat', offsets, scale_factor, axis_i=0)
else:
scale_factor = _unsqueeze_helper(g, scale_factor, 0)
scale_factor = g.op(
'Cast', scale_factor, to_i=cast_pytorch_to_onnx['Float'])
scales = [scale_factor for i in range(dim - 2)]
scale_factor = g.op('Concat', offsets, *scales, axis_i=0)
return scale_factor
def _size_helper(g, self, dim):
full_shape = g.op('Shape', self)
from torch.onnx.symbolic_opset9 import select
return select(g, full_shape, g.op('Constant', value_t=torch.tensor([0])),
dim)
def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override,
name):
if divisor_override and divisor_override.node().kind() != 'prim::Constant':
return _unimplemented(name, 'divisor_override')
if not stride:
stride = kernel_size
padding = tuple(tuple_fn(padding))
return padding
# Metaprogram symbolics for each ATen native specialized cast operator.
# For example, we specify a function named `_cast_uint8_t` that instantiates
# an ONNX cast node with the `to` attribute set to 'UINT8'.
#
# TODO: remove these once we support Type's in the JIT IR and we can once again
# use the unified toType operator
cast_pytorch_to_onnx = {
'Byte': torch.onnx.TensorProtoDataType.UINT8,
'Char': torch.onnx.TensorProtoDataType.INT8,
'Double': torch.onnx.TensorProtoDataType.DOUBLE,
'Float': torch.onnx.TensorProtoDataType.FLOAT,
'Half': torch.onnx.TensorProtoDataType.FLOAT16,
'Int': torch.onnx.TensorProtoDataType.INT32,
'Long': torch.onnx.TensorProtoDataType.INT64,
'Short': torch.onnx.TensorProtoDataType.INT16,
'Bool': torch.onnx.TensorProtoDataType.BOOL,
'ComplexFloat': torch.onnx.TensorProtoDataType.COMPLEX64,
'ComplexDouble': torch.onnx.TensorProtoDataType.COMPLEX128,
'Undefined': torch.onnx.TensorProtoDataType.UNDEFINED,
}
# Global set to store the list of quantized operators in the network.
# This is currently only used in the conversion of quantized ops from PT
# -> C2 via ONNX.
_quantized_ops = set()
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/pytorch/pytorch."""
import os
import numpy as np
import torch
from torch.nn.modules.utils import _pair, _single, _triple
from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_registry import register_op
from .onnx_utils import symbolic_helper as sym_help
def _interpolate(name, dim, interpolate_mode):
def symbolic_fn(g, input, output_size, *args):
scales, align_corners = sym_help._get_interpolate_attributes(
g, interpolate_mode, args)
align_corners = sym_help._maybe_get_scalar(align_corners)
transformation_mode = 'asymmetric' \
if interpolate_mode == 'nearest' \
else 'align_corners' if align_corners else 'pytorch_half_pixel'
empty_tensor = g.op(
'Constant', value_t=torch.tensor([], dtype=torch.float32))
if scales is None:
if 'ONNX_BACKEND' in os.environ and os.environ[
'ONNX_BACKEND'] == 'TensorRT':
input_size = input.type().sizes()
# slice the first two dim
input_size = input_size[:2]
# convert output_size to int type
output_size = sym_help._maybe_get_const(output_size, 'is')
input_size.extend(output_size)
output_size = g.op(
'Constant',
value_t=torch.tensor(input_size, dtype=torch.int64))
else:
input_size = g.op('Shape', input)
input_size_beg = sym_help._slice_helper(
g, input_size, axes=[0], ends=[2], starts=[0])
output_size = g.op(
'Cast',
output_size,
to_i=sym_help.cast_pytorch_to_onnx['Long'])
output_size = g.op(
'Concat', input_size_beg, output_size, axis_i=0)
scales = g.op(
'Constant', value_t=torch.tensor([], dtype=torch.float32))
return g.op(
'Resize',
input,
empty_tensor,
# roi only takes effect with
# coordinate_transformation_mode="tf_crop_and_resize"
scales, # scales is not needed since we are sending out_size
output_size,
coordinate_transformation_mode_s=transformation_mode,
cubic_coeff_a_f=-0.75, # only valid when mode="cubic"
mode_s=interpolate_mode, # nearest, linear, or cubic
nearest_mode_s='floor') # only valid when mode="nearest"
else:
return g.op(
'Resize',
input,
empty_tensor,
# roi only takes effect with
# coordinate_transformation_mode="tf_crop_and_resize"
                scales,  # scales drives Resize here since sizes is omitted
coordinate_transformation_mode_s=transformation_mode,
cubic_coeff_a_f=-0.75, # only valid when mode="cubic"
mode_s=interpolate_mode, # nearest, linear, or cubic
nearest_mode_s='floor') # only valid when mode="nearest"
return symbolic_fn
upsample_nearest1d = _interpolate('upsample_nearest1d', 3, 'nearest')
upsample_nearest2d = _interpolate('upsample_nearest2d', 4, 'nearest')
upsample_nearest3d = _interpolate('upsample_nearest3d', 5, 'nearest')
upsample_linear1d = _interpolate('upsample_linear1d', 3, 'linear')
upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, 'linear')
upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, 'linear')
upsample_bicubic2d = _interpolate('upsample_bicubic2d', 4, 'cubic')
@parse_args('v', 'v', 'i', 'i', 'i', 'none')
def topk(g, self, k, dim, largest, sorted, out=None):
return sym_help._topk_helper(
g, self, k, dim, largest=largest, sorted=sorted, out=out)
def masked_select(g, self, mask):
from torch.onnx.symbolic_opset9 import expand_as, nonzero
index = nonzero(g, expand_as(g, mask, self))
return g.op('GatherND', self, index)
def _prepare_onnx_paddings(g, dim, pad):
pad_len = torch.onnx.symbolic_opset9.size(
g, pad, g.op('Constant', value_t=torch.tensor([0])))
# Set extension = [0] * (dim * 2 - len(pad))
extension = g.op(
'Sub',
g.op('Mul',
g.op('Constant', value_t=torch.tensor(dim, dtype=torch.int64)),
g.op('Constant', value_t=torch.tensor(2, dtype=torch.int64))),
pad_len)
pad = g.op('Cast', pad, to_i=sym_help.cast_pytorch_to_onnx['Long'])
paddings = g.op(
'Concat',
pad,
g.op(
'ConstantOfShape',
extension,
value_t=torch.tensor([0], dtype=torch.int64)),
axis_i=0)
paddings = g.op('Reshape', paddings,
g.op('Constant', value_t=torch.tensor([-1, 2])))
paddings = g.op(
'Transpose',
torch.onnx.symbolic_opset10.flip(g, paddings, [0]),
perm_i=[1, 0])
paddings = g.op('Reshape', paddings,
g.op('Constant', value_t=torch.tensor([-1])))
padding_c = g.op(
'Cast', paddings, to_i=sym_help.cast_pytorch_to_onnx['Long'])
return padding_c
def constant_pad_nd(g, input, padding, value=None):
mode = 'constant'
value = sym_help._maybe_get_scalar(value)
value = sym_help._if_scalar_type_as(g, value, input)
pad = _prepare_onnx_paddings(g, input.type().dim(), padding)
return g.op('Pad', input, pad, value, mode_s=mode)
def reflection_pad(g, input, padding):
mode = 'reflect'
paddings = _prepare_onnx_paddings(g, input.type().dim(), padding)
return g.op('Pad', input, paddings, mode_s=mode)
reflection_pad1d = reflection_pad
reflection_pad2d = reflection_pad
reflection_pad3d = reflection_pad
def _avg_pool(name, tuple_fn):
@parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none')
def symbolic_fn(g,
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override=None):
padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size,
stride, divisor_override, name)
if not stride:
stride = kernel_size
if count_include_pad:
input = g.op(
'Pad',
input,
g.op(
'Constant',
value_t=torch.tensor(((0, ) * 2 + padding) * 2)),
mode_s='constant')
padding = (0, ) * len(padding)
output = g.op(
'AveragePool',
input,
kernel_shape_i=tuple_fn(kernel_size),
strides_i=tuple_fn(stride),
pads_i=padding * 2,
ceil_mode_i=ceil_mode)
return output
return symbolic_fn
avg_pool1d = _avg_pool('avg_pool1d', _single)
avg_pool2d = _avg_pool('avg_pool2d', _pair)
avg_pool3d = _avg_pool('avg_pool3d', _triple)
def _get_im2col_indices_along_dim(g, input_d, kernel_size_d, dilation_d,
padding_d, stride_d):
# Input is always 4-D (N, C, H, W)
# Calculate indices of sliding blocks along spatial dimension
    # Slide the kernel over the input along each dim d:
    # each dimension d ranges from 0 to
    # input[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1)
    # with step = stride
blocks_d = g.op('Add', input_d,
g.op('Constant', value_t=torch.tensor(padding_d * 2)))
blocks_d = g.op(
'Sub', blocks_d,
g.op(
'Constant',
value_t=torch.tensor(dilation_d * (kernel_size_d - 1))))
# Stride kernel over input and find starting indices along dim d
blocks_d_indices = g.op('Range', g.op('Constant', value_t=torch.tensor(0)),
blocks_d,
g.op('Constant', value_t=torch.tensor(stride_d)))
# Apply dilation on kernel and find its indices along dim d
kernel_grid = np.arange(0, kernel_size_d * dilation_d, dilation_d)
kernel_grid = g.op('Constant', value_t=torch.tensor([kernel_grid]))
    # Broadcast and add kernel starting positions (indices) with
# kernel_grid along dim d, to get block indices along dim d
blocks_d_indices = g.op(
'Unsqueeze', blocks_d_indices, axes_i=[0]) # Reshape to [1, -1]
kernel_mask = g.op('Reshape', kernel_grid,
g.op('Constant', value_t=torch.tensor([-1, 1])))
block_mask = g.op('Add', blocks_d_indices, kernel_mask)
return block_mask
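# Worked example (illustrative, assuming input_d=5, kernel_size_d=3,
# dilation_d=1, padding_d=0, stride_d=1):
#   blocks_d         = 5 + 2 * 0 - 1 * (3 - 1) = 3
#   blocks_d_indices = Range(0, 3, 1) = [0, 1, 2]   (block start indices)
#   kernel_grid      = [0, 1, 2]                    (offsets within a block)
#   block_mask[k, b] = kernel_grid[k] + blocks_d_indices[b], i.e.
#                      [[0, 1, 2], [1, 2, 3], [2, 3, 4]]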
def _get_im2col_padded_input(g, input, padding_h, padding_w):
# Input is always 4-D tensor (N, C, H, W)
# Padding tensor has the following format: (padding_h, padding_w)
# Reshape the padding to follow ONNX format:
# (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...)
pad = g.op(
'Constant', value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2))
return g.op('Pad', input, pad)
def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
batch_dim = size(g, input, g.op('Constant', value_t=torch.tensor(0)))
channel_dim = size(g, input, g.op('Constant', value_t=torch.tensor(1)))
channel_unfolded = g.op(
'Mul', channel_dim,
g.op('Constant', value_t=torch.tensor(kernel_h * kernel_w)))
return g.op(
'Concat',
g.op('Unsqueeze', batch_dim, axes_i=[0]),
g.op('Unsqueeze', channel_unfolded, axes_i=[0]),
g.op('Constant', value_t=torch.tensor([-1])),
axis_i=0)
def size(g, self, dim=None):
if dim is None:
return g.op('Shape', self)
return sym_help._size_helper(g, self, dim)
@parse_args('v', 'is', 'is', 'is', 'is')
def im2col(g, input, kernel_size, dilation, padding, stride):
# Input is always 4-D tensor (N, C, H, W)
# All other args are int[2]
input_h = size(g, input, g.op('Constant', value_t=torch.tensor(2)))
input_w = size(g, input, g.op('Constant', value_t=torch.tensor(3)))
stride_h, stride_w = stride[0], stride[1]
padding_h, padding_w = padding[0], padding[1]
dilation_h, dilation_w = dilation[0], dilation[1]
kernel_h, kernel_w = kernel_size[0], kernel_size[1]
blocks_row_indices = _get_im2col_indices_along_dim(g, input_h, kernel_h,
dilation_h, padding_h,
stride_h)
blocks_col_indices = _get_im2col_indices_along_dim(g, input_w, kernel_w,
dilation_w, padding_w,
stride_w)
output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)
output = g.op('Gather', padded_input, blocks_row_indices, axis_i=2)
output = g.op('Gather', output, blocks_col_indices, axis_i=4)
output = g.op('Transpose', output, perm_i=[0, 1, 2, 4, 3, 5])
return g.op('Reshape', output, output_shape)
@parse_args('v', 'i')
def one_hot(g, self, num_classes):
values = g.op('Constant', value_t=torch.LongTensor([0, 1]))
depth = g.op('Constant', value_t=torch.LongTensor([num_classes]))
return g.op('OneHot', self, depth, values, axis_i=-1)
@parse_args('v', 'i', 'none')
def softmax(g, input, dim, dtype=None):
input_dim = input.type().dim()
if input_dim:
# TODO: remove this as onnx opset 11 spec allows negative axes
if dim < 0:
dim = input_dim + dim
if input_dim == dim + 1:
softmax = g.op('Softmax', input, axis_i=dim)
if dtype and dtype.node().kind() != 'prim::Constant':
parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
softmax = g.op(
'Cast',
softmax,
to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
return softmax
max_value = g.op('ReduceMax', input, axes_i=[dim], keepdims_i=1)
input = g.op('Sub', input, max_value)
exp = g.op('Exp', input)
sum = g.op('ReduceSum', exp, axes_i=[dim])
softmax = g.op('Div', exp, sum)
if dtype and dtype.node().kind() != 'prim::Constant':
parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
softmax = g.op(
'Cast', softmax, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
return softmax
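# Note (illustrative): the fallback branch above builds the numerically stable
# decomposition
#     softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
# out of ReduceMax, Sub, Exp, ReduceSum and Div nodes; it is mathematically
# equivalent to the single Softmax node used when `dim` is the last dimension.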
def _adaptive_pool(name, type, tuple_fn, fn=None):
@parse_args('v', 'is')
def symbolic_fn(g, input, output_size):
if output_size == [1] * len(output_size) and type == 'AveragePool':
return g.op('GlobalAveragePool', input)
if not input.isCompleteTensor():
if output_size == [1] * len(output_size):
return g.op('GlobalMaxPool', input), None
raise NotImplementedError(
'[Adaptive pool]:input size not accessible')
dim = input.type().sizes()[2:]
if output_size == [1] * len(output_size) and type == 'MaxPool':
return g.op('GlobalMaxPool', input), None
# compute stride = floor(input_size / output_size)
s = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
# compute kernel_size = input_size - (output_size - 1) * stride
k = [dim[i] - (output_size[i] - 1) * s[i] for i in range(0, len(dim))]
# call max_poolxd_with_indices to get indices in the output
if type == 'MaxPool':
return fn(g, input, k, k, (0, ) * len(dim), (1, ) * len(dim),
False)
output = g.op(
type,
input,
kernel_shape_i=tuple_fn(k),
strides_i=tuple_fn(s),
ceil_mode_i=False)
return output
return symbolic_fn
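# Worked example (illustrative): for an input spatial size of 13 and an
# adaptive output size of 4, the code above picks
#   stride      s = floor(13 / 4) = 3
#   kernel_size k = 13 - (4 - 1) * 3 = 4
# so four windows of size 4 with stride 3 cover positions 0-3, 3-6, 6-9, 9-12.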
adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', 'AveragePool',
_single)
adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', 'AveragePool',
_pair)
adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',
_triple)
def new_full(g,
self,
size,
fill_value,
dtype,
layout,
device,
pin_memory=False):
from torch.onnx.symbolic_opset9 import full
if dtype is None and self.isCompleteTensor():
dtype = self.type().scalarType()
dtype = sym_help.scalar_type_to_onnx.index(
sym_help.cast_pytorch_to_onnx[dtype])
return full(g, size, fill_value, dtype, layout, device, pin_memory)
@parse_args('v', 'v', 'i', 'i', 'i')
def grid_sampler(g,
input,
grid,
interpolation_mode,
padding_mode,
align_corners=False):
return g.op(
'mmcv::grid_sampler',
input,
grid,
interpolation_mode_i=interpolation_mode,
padding_mode_i=padding_mode,
align_corners_i=align_corners)
@parse_args('v', 'i')
def cummax(g, input, dim):
return g.op('mmcv::cummax', input, dim_i=dim, outputs=2)
@parse_args('v', 'i')
def cummin(g, input, dim):
return g.op('mmcv::cummin', input, dim_i=dim, outputs=2)
@parse_args('v', 'v', 'is')
def roll(g, input, shifts, dims):
from torch.onnx.symbolic_opset9 import squeeze
from packaging import version
input_shape = g.op('Shape', input)
need_flatten = len(dims) == 0
# If dims is not specified, the tensor will be flattened before
# rolling and then restored to the original shape.
if need_flatten:
resize_shape = input_shape
input = g.op('Reshape', input,
g.op('Constant', value_t=torch.LongTensor([1, -1])))
input_shape = g.op('Shape', input)
dims = [1]
for index, dim in enumerate(dims):
end_size = sym_help._slice_helper(
g, input_shape, axes=[0], ends=[dim + 1], starts=[dim])
shift_size = sym_help._slice_helper(
g, shifts, axes=[0], ends=[index + 1], starts=[index])
slice_size = g.op('Sub', end_size, shift_size)
        # Cannot use Mod because TensorRT does not support it
div_size = g.op('Div', slice_size, end_size)
slice_size = g.op('Sub', slice_size, g.op('Mul', end_size, div_size))
if version.parse(torch.__version__) >= version.parse('1.7.0'):
# add dim=0 for pytorch 1.9.0
end_size = squeeze(g, end_size, 0)
slice_size = squeeze(g, slice_size, 0)
else:
end_size = g.op('Squeeze', end_size)
slice_size = g.op('Squeeze', slice_size)
dim = torch.LongTensor([dim])
input_slice0 = sym_help._slice_helper(
g,
input,
axes=dim,
starts=torch.LongTensor([0]),
ends=slice_size,
dynamic_slice=True)
input_slice1 = sym_help._slice_helper(
g,
input,
axes=dim,
ends=end_size,
starts=slice_size,
dynamic_slice=True)
input = g.op('Concat', input_slice1, input_slice0, axis_i=dim)
if need_flatten:
input = g.op('Reshape', input, resize_shape)
return input
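# Worked example (illustrative): rolling the 1-D tensor [0, 1, 2, 3, 4] with
# shifts=2 along dim 0 gives slice_size = 5 - 2 = 3, so the two slices are
# input[3:5] = [3, 4] and input[0:3] = [0, 1, 2]; concatenating them yields
# [3, 4, 0, 1, 2], matching torch.roll.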
def register_extra_symbolics(opset=11):
register_op('one_hot', one_hot, '', opset)
register_op('im2col', im2col, '', opset)
register_op('topk', topk, '', opset)
register_op('softmax', softmax, '', opset)
register_op('constant_pad_nd', constant_pad_nd, '', opset)
register_op('reflection_pad1d', reflection_pad1d, '', opset)
register_op('reflection_pad2d', reflection_pad2d, '', opset)
register_op('reflection_pad3d', reflection_pad3d, '', opset)
register_op('avg_pool1d', avg_pool1d, '', opset)
register_op('avg_pool2d', avg_pool2d, '', opset)
register_op('avg_pool3d', avg_pool3d, '', opset)
register_op('adaptive_avg_pool1d', adaptive_avg_pool1d, '', opset)
register_op('adaptive_avg_pool2d', adaptive_avg_pool2d, '', opset)
register_op('adaptive_avg_pool3d', adaptive_avg_pool3d, '', opset)
register_op('masked_select', masked_select, '', opset)
register_op('upsample_nearest1d', upsample_nearest1d, '', opset)
register_op('upsample_nearest2d', upsample_nearest2d, '', opset)
register_op('upsample_nearest3d', upsample_nearest3d, '', opset)
register_op('upsample_linear1d', upsample_linear1d, '', opset)
register_op('upsample_bilinear2d', upsample_bilinear2d, '', opset)
register_op('upsample_trilinear3d', upsample_trilinear3d, '', opset)
register_op('upsample_bicubic2d', upsample_bicubic2d, '', opset)
register_op('new_full', new_full, '', opset)
register_op('grid_sampler', grid_sampler, '', opset)
register_op('cummax', cummax, '', opset)
register_op('cummin', cummin, '', opset)
register_op('roll', roll, '', opset)
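# A minimal usage sketch (illustrative, assuming this module is exposed as
# mmcv.onnx): the extra symbolics must be registered before export, e.g.
#
#     from mmcv.onnx import register_extra_symbolics
#     register_extra_symbolics(opset=11)
#     torch.onnx.export(model, dummy_input, 'model.onnx', opset_version=11)
#
# where `model` and `dummy_input` stand in for the user's network and a
# matching example input.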
# Copyright (c) OpenMMLab. All rights reserved.
from .active_rotated_filter import active_rotated_filter
from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
from .bbox import bbox_overlaps
from .bezier_align import BezierAlign, bezier_align
from .bias_act import bias_act
from .border_align import BorderAlign, border_align
from .box_iou_quadri import box_iou_quadri
from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .chamfer_distance import chamfer_distance
from .contour_expand import contour_expand
from .conv2d_gradfix import conv2d, conv_transpose2d
from .convex_iou import convex_giou, convex_iou
from .corner_pool import CornerPool
from .correlation import Correlation
from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
......@@ -23,8 +16,6 @@ from .deprecated_wrappers import Conv2d_deprecated as Conv2d
from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
from .deprecated_wrappers import Linear_deprecated as Linear
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d
from .filtered_lrelu import filtered_lrelu
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss)
from .furthest_point_sample import (furthest_point_sample,
......@@ -32,46 +23,35 @@ from .furthest_point_sample import (furthest_point_sample,
from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
from .gather_points import gather_points
from .group_points import GroupAll, QueryAndGroup, grouping_operation
from .info import get_compiler_version, get_compiling_cuda_version
from .iou3d import (boxes_iou3d, boxes_iou_bev, boxes_overlap_bev, nms3d,
nms3d_normal, nms_bev, nms_normal_bev)
from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path)
from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
from .knn import knn
from .masked_conv import MaskedConv2d, masked_conv2d
from .min_area_polygons import min_area_polygons
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
modulated_deform_conv2d)
from .multi_scale_deform_attn import MultiScaleDeformableAttention
from .nms import batched_nms, nms, nms_match, nms_quadri, nms_rotated, soft_nms
from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
from .pixel_group import pixel_group
from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point)
from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
from .points_in_polygons import points_in_polygons
from .points_sampler import PointsSampler
from .prroi_pool import PrRoIPool, prroi_pool
from .psa_mask import PSAMask
from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
from .roi_align import RoIAlign, roi_align
from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
from .roi_pool import RoIPool, roi_pool
from .roiaware_pool3d import RoIAwarePool3d
from .roipoint_pool3d import RoIPointPool3d
from .rotated_feature_align import rotated_feature_align
from .saconv import SAConv2d
from .scatter_points import DynamicScatter, dynamic_scatter
from .sparse_conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
SparseConvTranspose3d, SparseInverseConv2d,
SparseInverseConv3d, SubMConv2d, SubMConv3d)
from .sparse_modules import SparseModule, SparseSequential
from .sparse_pool import SparseMaxPool2d, SparseMaxPool3d
from .sparse_structure import SparseConvTensor, scatter_nd
from .sync_bn import SyncBatchNorm
from .three_interpolate import three_interpolate
from .three_nn import three_nn
from .tin_shift import TINShift, tin_shift
from .upfirdn2d import filter2d, upfirdn2d, upsample2d
from .upfirdn2d import upfirdn2d
from .voxelize import Voxelization, voxelization
__all__ = [
......@@ -80,32 +60,22 @@ __all__ = [
'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
'get_compiler_version', 'get_compiling_cuda_version', 'MaskedConv2d',
'masked_conv2d', 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'get_compiler_version', 'get_compiling_cuda_version',
'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
'box_iou_rotated', 'box_iou_quadri', 'RoIPointPool3d', 'nms_rotated',
'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
'fused_bias_leakyrelu', 'rotated_feature_align', 'RiRoIAlignRotated',
'riroi_align_rotated', 'RoIAlignRotated', 'roi_align_rotated',
'pixel_group', 'QueryAndGroup', 'GroupAll', 'grouping_operation',
'contour_expand', 'three_nn', 'three_interpolate',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
'gather_points', 'furthest_point_sample', 'nms_quadri',
'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
'border_align', 'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
'boxes_iou3d', 'boxes_iou_bev', 'boxes_overlap_bev', 'nms_bev',
'nms_normal_bev', 'nms3d', 'nms3d_normal', 'Voxelization', 'voxelization',
'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', 'SparseConv2d',
'SparseConv3d', 'SparseConvTranspose2d', 'SparseConvTranspose3d',
'SparseInverseConv2d', 'SparseInverseConv3d', 'SubMConv2d', 'SubMConv3d',
'SparseModule', 'SparseSequential', 'SparseMaxPool2d', 'SparseMaxPool3d',
'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part',
'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
'PrRoIPool', 'prroi_pool', 'bias_act', 'filtered_lrelu', 'conv2d',
'conv_transpose2d', 'filter2d', 'upsample2d', 'BezierAlign', 'bezier_align'
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext',
['active_rotated_filter_forward', 'active_rotated_filter_backward'])
class ActiveRotatedFilterFunction(Function):
"""Encoding the orientation information and generating orientation-
sensitive features.
    The details are described in the paper `Align Deep Features for Oriented
    Object Detection <https://arxiv.org/abs/2008.09397>`_.
"""
@staticmethod
def forward(ctx, input: torch.Tensor,
indices: torch.Tensor) -> torch.Tensor:
"""
Args:
input (torch.Tensor): Input features with shape
[num_output_planes, num_input_planes, num_orientations, H, W].
indices (torch.Tensor): Indices with shape
[num_orientations, H, W, num_rotations].
Returns:
torch.Tensor: Refined features with shape [num_output_planes *
num_rotations, num_input_planes * num_orientations, H, W].
"""
ctx.save_for_backward(input, indices)
op, ip, o, h, w = input.size()
o, h, w, r = indices.size()
output = input.new_zeros((op * r, ip * o, h, w))
ext_module.active_rotated_filter_forward(input, indices, output)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""
Args:
            grad_out (torch.Tensor): The gradient of output features
with shape [num_output_planes * num_rotations,
num_input_planes * num_orientations, H, W].
Returns:
torch.Tensor: The gradient of input features with shape
[num_output_planes, num_input_planes, num_orientations, H, W].
"""
input, indices = ctx.saved_tensors
grad_in = torch.zeros_like(input)
ext_module.active_rotated_filter_backward(grad_out, indices, grad_in)
return grad_in, None
active_rotated_filter = ActiveRotatedFilterFunction.apply
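# Illustrative shape check (a hedged sketch, not part of the original file):
# with num_orientations o = 8, num_rotations r = 8 and a 3x3 kernel,
#
#     # weight  : (256, 256, 8, 3, 3) CUDA tensor
#     # indices : (8, 3, 3, 8) int tensor of rotation indices
#     # out = active_rotated_filter(weight, indices)  # -> (2048, 2048, 3, 3)
#
# which matches the [op * r, ip * o, H, W] output described in the docstring.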
from typing import Tuple
import torch
from torch.autograd import Function
from ..utils import ext_loader
......@@ -30,11 +27,11 @@ class AssignScoreWithK(Function):
@staticmethod
def forward(ctx,
scores: torch.Tensor,
point_features: torch.Tensor,
center_features: torch.Tensor,
knn_idx: torch.Tensor,
aggregate: str = 'sum') -> torch.Tensor:
scores,
point_features,
center_features,
knn_idx,
aggregate='sum'):
"""
Args:
scores (torch.Tensor): (B, npoint, K, M), predicted scores to
......@@ -81,20 +78,15 @@ class AssignScoreWithK(Function):
return output
@staticmethod
def backward(
ctx, grad_out: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
def backward(ctx, grad_out):
"""
Args:
grad_out (torch.Tensor): (B, out_dim, npoint, K)
Returns:
tuple[torch.Tensor]: A tuple contains five elements. The first one
is the gradient of ``scores`` whose shape is (B, npoint, K, M). The
second is the gradient of ``point_features`` whose shape is
(B, N, M, out_dim). The third is the gradient of
``center_features`` with the shape of (B, N, M, out_dim). The last
two are ``None``.
grad_scores (torch.Tensor): (B, npoint, K, M)
grad_point_features (torch.Tensor): (B, N, M, out_dim)
grad_center_features (torch.Tensor): (B, N, M, out_dim)
"""
_, point_features, center_features, scores, knn_idx = ctx.saved_tensors
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import torch
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['ball_query_forward', 'stack_ball_query_forward'])
ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
class BallQuery(Function):
"""Find nearby points in spherical space."""
@staticmethod
def forward(
ctx,
min_radius: float,
max_radius: float,
sample_num: int,
xyz: torch.Tensor,
center_xyz: torch.Tensor,
xyz_batch_cnt: Optional[torch.Tensor] = None,
center_xyz_batch_cnt: Optional[torch.Tensor] = None
) -> torch.Tensor:
def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
"""
Args:
min_radius (float): minimum radius of the balls.
max_radius (float): maximum radius of the balls.
sample_num (int): maximum number of features in the balls.
xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features,
                or stacked input (N1 + N2 ..., 3).
center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball
                query, or stacked input (M1 + M2 ..., 3).
xyz_batch_cnt: (batch_size): Stacked input xyz coordinates nums in
each batch, just like (N1, N2, ...). Defaults to None.
New in version 1.7.0.
center_xyz_batch_cnt: (batch_size): Stacked centers coordinates
                nums in each batch, just like (M1, M2, ...). Defaults to None.
New in version 1.7.0.
xyz (Tensor): (B, N, 3) xyz coordinates of the features.
center_xyz (Tensor): (B, npoint, 3) centers of the ball query.
Returns:
torch.Tensor: (B, npoint, nsample) tensor with the indices of the
features that form the query balls.
Tensor: (B, npoint, nsample) tensor with the indices of
the features that form the query balls.
"""
assert center_xyz.is_contiguous()
assert xyz.is_contiguous()
assert min_radius < max_radius
if xyz_batch_cnt is not None and center_xyz_batch_cnt is not None:
assert xyz_batch_cnt.dtype == torch.int
assert center_xyz_batch_cnt.dtype == torch.int
idx = center_xyz.new_zeros((center_xyz.shape[0], sample_num),
dtype=torch.int32)
ext_module.stack_ball_query_forward(
center_xyz,
center_xyz_batch_cnt,
xyz,
xyz_batch_cnt,
idx,
max_radius=max_radius,
nsample=sample_num,
)
else:
B, N, _ = xyz.size()
npoint = center_xyz.size(1)
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int32)
ext_module.ball_query_forward(
center_xyz,
xyz,
idx,
b=B,
n=N,
m=npoint,
min_radius=min_radius,
max_radius=max_radius,
nsample=sample_num)
B, N, _ = xyz.size()
npoint = center_xyz.size(1)
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)
ext_module.ball_query_forward(
center_xyz,
xyz,
idx,
b=B,
n=N,
m=npoint,
min_radius=min_radius,
max_radius=max_radius,
nsample=sample_num)
if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx)
return idx
@staticmethod
def backward(ctx, a=None) -> Tuple[None, None, None, None]:
def backward(ctx, a=None):
return None, None, None, None
......
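# A minimal usage sketch (illustrative; `ball_query` is the functional alias
# exported from mmcv.ops):
#
#     # xyz        : (B, N, 3) point coordinates
#     # center_xyz : (B, npoint, 3) query-ball centers
#     # idx = ball_query(0.0, 0.8, 16, xyz, center_xyz)  # (B, npoint, 16) ints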
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['bbox_overlaps'])
def _bbox_overlaps_cpu(bboxes1: torch.Tensor,
bboxes2: torch.Tensor,
mode: str = 'iou',
aligned: bool = False,
offset: int = 0) -> torch.Tensor:
assert mode in ['iou', 'iof']
if aligned:
lt = torch.max(bboxes1[:, :2], bboxes2[:, :2]) # [rows, 2]
rb = torch.min(bboxes1[:, 2:], bboxes2[:, 2:]) # [rows, 2]
wh = (rb - lt + offset).clamp(min=0) # [rows, 2]
overlap = wh[:, 0] * wh[:, 1]
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + offset) * (
bboxes1[:, 3] - bboxes1[:, 1] + offset)
if mode == 'iou':
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + offset) * (
bboxes2[:, 3] - bboxes2[:, 1] + offset)
ious = overlap / (area1 + area2 - overlap)
else:
ious = overlap / area1
else:
lt = torch.max(bboxes1[:, None, :2], bboxes2[:, :2]) # [rows, cols, 2]
rb = torch.min(bboxes1[:, None, 2:], bboxes2[:, 2:]) # [rows, cols, 2]
wh = (rb - lt + offset).clamp(min=0) # [rows, cols, 2]
overlap = wh[:, :, 0] * wh[:, :, 1]
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + offset) * (
bboxes1[:, 3] - bboxes1[:, 1] + offset)
if mode == 'iou':
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + offset) * (
bboxes2[:, 3] - bboxes2[:, 1] + offset)
ious = overlap / (area1[:, None] + area2 - overlap)
else:
ious = overlap / (area1[:, None])
return ious
def bbox_overlaps(bboxes1: torch.Tensor,
bboxes2: torch.Tensor,
mode: str = 'iou',
aligned: bool = False,
offset: int = 0) -> torch.Tensor:
def bbox_overlaps(bboxes1, bboxes2, mode='iou', aligned=False, offset=0):
"""Calculate overlap between two set of bboxes.
If ``aligned`` is ``False``, then calculate the ious between each bbox
......@@ -59,16 +12,14 @@ def bbox_overlaps(bboxes1: torch.Tensor,
bboxes1 and bboxes2.
Args:
bboxes1 (torch.Tensor): shape (m, 4) in <x1, y1, x2, y2> format or
empty.
bboxes2 (torch.Tensor): shape (n, 4) in <x1, y1, x2, y2> format or
empty. If aligned is ``True``, then m and n must be equal.
bboxes1 (Tensor): shape (m, 4) in <x1, y1, x2, y2> format or empty.
bboxes2 (Tensor): shape (n, 4) in <x1, y1, x2, y2> format or empty.
If aligned is ``True``, then m and n must be equal.
mode (str): "iou" (intersection over union) or iof (intersection over
foreground).
Returns:
torch.Tensor: Return the ious betweens boxes. If ``aligned`` is
``False``, the shape of ious is (m, n) else (m, 1).
ious(Tensor): shape (m, n) if aligned == False else shape (m, 1)
Example:
>>> bboxes1 = torch.FloatTensor([
......@@ -106,17 +57,16 @@ def bbox_overlaps(bboxes1: torch.Tensor,
rows = bboxes1.size(0)
cols = bboxes2.size(0)
if aligned:
assert rows == cols
ious = bboxes1.new_zeros(rows)
else:
ious = bboxes1.new_zeros((rows, cols))
if rows * cols == 0:
return ious
return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols)
if aligned:
ious = bboxes1.new_zeros(rows)
else:
ious = bboxes1.new_zeros((rows, cols))
ext_module.bbox_overlaps(
bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset)
return ious
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['bezier_align_forward', 'bezier_align_backward'])
class BezierAlignFunction(Function):
@staticmethod
def forward(ctx,
input: torch.Tensor,
beziers: torch.Tensor,
output_size: Union[int, Tuple[int, int]],
spatial_scale: Union[int, float] = 1.0,
sampling_ratio: int = 0,
aligned: bool = True) -> torch.Tensor:
ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale
ctx.input_shape = input.size()
ctx.sampling_ratio = sampling_ratio
ctx.aligned = aligned
assert beziers.size(1) == 17
output_shape = (beziers.size(0), input.size(1), ctx.output_size[0],
ctx.output_size[1])
output = input.new_zeros(output_shape)
ext_module.bezier_align_forward(
input,
beziers,
output,
aligned_height=ctx.output_size[0],
aligned_width=ctx.output_size[1],
spatial_scale=ctx.spatial_scale,
sampling_ratio=ctx.sampling_ratio,
aligned=ctx.aligned)
ctx.save_for_backward(beziers)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output: torch.Tensor):
beziers = ctx.saved_tensors[0]
grad_input = grad_output.new_zeros(ctx.input_shape)
grad_output = grad_output.contiguous()
ext_module.bezier_align_backward(
grad_output,
beziers,
grad_input,
aligned_height=ctx.output_size[0],
aligned_width=ctx.output_size[1],
spatial_scale=ctx.spatial_scale,
sampling_ratio=ctx.sampling_ratio,
aligned=ctx.aligned)
return grad_input, None, None, None, None, None
bezier_align = BezierAlignFunction.apply
class BezierAlign(nn.Module):
"""Bezier align pooling layer.
Args:
output_size (tuple): h, w
spatial_scale (float): scale the input boxes by this number
        sampling_ratio (int): number of input samples to take for each
output sample. 0 to take samples densely for current models.
aligned (bool): if False, use the legacy implementation in
MMDetection. If True, align the results more perfectly.
Note:
The implementation of BezierAlign is modified from
https://github.com/aim-uofa/AdelaiDet
The meaning of aligned=True:
Given a continuous coordinate c, its two neighboring pixel
indices (in our pixel model) are computed by floor(c - 0.5) and
ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
indices [0] and [1] (which are sampled from the underlying signal
at continuous coordinates 0.5 and 1.5). But the original roi_align
(aligned=False) does not subtract the 0.5 when computing
neighboring pixel indices and therefore it uses pixels with a
slightly incorrect alignment (relative to our pixel model) when
performing bilinear interpolation.
With `aligned=True`,
we first appropriately scale the ROI and then shift it by -0.5
prior to calling roi_align. This produces the correct neighbors;
        the difference does not affect the model's performance if ROIAlign
        is used together with conv layers.
"""
def __init__(
self,
output_size: Tuple,
spatial_scale: Union[int, float],
sampling_ratio: int,
aligned: bool = True,
) -> None:
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)
self.sampling_ratio = int(sampling_ratio)
self.aligned = aligned
def forward(self, input: torch.Tensor,
beziers: torch.Tensor) -> torch.Tensor:
"""BezierAlign forward.
Args:
            input (Tensor): input features.
beziers (Tensor): beziers for align.
"""
return bezier_align(input, beziers, self.output_size,
self.spatial_scale, self.sampling_ratio,
self.aligned)
def __repr__(self):
s = self.__class__.__name__
s += f'(output_size={self.output_size}, '
        s += f'spatial_scale={self.spatial_scale}, '
        s += f'sampling_ratio={self.sampling_ratio}, '
        s += f'aligned={self.aligned})'
return s
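# A minimal usage sketch (illustrative; shapes follow the docstrings above):
#
#     # feats   : (N, C, H, W) input feature map
#     # beziers : (K, 17) rois, each row = batch index + 16 Bezier control
#     #           point coordinates
#     # align = BezierAlign(output_size=(16, 64), spatial_scale=0.25,
#     #                     sampling_ratio=0)
#     # out = align(feats, beziers)  # -> (K, C, 16, 64)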
# Modified from
# https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/bias_act.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# source: https://github.com/open-mmlab/mmediting/blob/dev-1.x/mmedit/models/editors/stylegan3/stylegan3_ops/ops/bias_act.py # noqa
"""Custom PyTorch ops for efficient bias and activation."""
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['bias_act'])
class EasyDict(dict):
"""Convenience class that behaves like a dict but allows access with the
attribute syntax."""
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name: str, value: Any) -> None:
self[name] = value
def __delattr__(self, name: str) -> None:
del self[name]
activation_funcs = {
'linear':
EasyDict(
func=lambda x, **_: x,
def_alpha=0,
def_gain=1,
cuda_idx=1,
ref='',
has_2nd_grad=False),
'relu':
EasyDict(
func=lambda x, **_: torch.nn.functional.relu(x),
def_alpha=0,
def_gain=np.sqrt(2),
cuda_idx=2,
ref='y',
has_2nd_grad=False),
'lrelu':
EasyDict(
func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
def_alpha=0.2,
def_gain=np.sqrt(2),
cuda_idx=3,
ref='y',
has_2nd_grad=False),
'tanh':
EasyDict(
func=lambda x, **_: torch.tanh(x),
def_alpha=0,
def_gain=1,
cuda_idx=4,
ref='y',
has_2nd_grad=True),
'sigmoid':
EasyDict(
func=lambda x, **_: torch.sigmoid(x),
def_alpha=0,
def_gain=1,
cuda_idx=5,
ref='y',
has_2nd_grad=True),
'elu':
EasyDict(
func=lambda x, **_: torch.nn.functional.elu(x),
def_alpha=0,
def_gain=1,
cuda_idx=6,
ref='y',
has_2nd_grad=True),
'selu':
EasyDict(
func=lambda x, **_: torch.nn.functional.selu(x),
def_alpha=0,
def_gain=1,
cuda_idx=7,
ref='y',
has_2nd_grad=True),
'softplus':
EasyDict(
func=lambda x, **_: torch.nn.functional.softplus(x),
def_alpha=0,
def_gain=1,
cuda_idx=8,
ref='y',
has_2nd_grad=True),
'swish':
EasyDict(
func=lambda x, **_: torch.sigmoid(x) * x,
def_alpha=0,
def_gain=np.sqrt(2),
cuda_idx=9,
ref='x',
has_2nd_grad=True),
}
_null_tensor = torch.empty([0])
def bias_act(input: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dim: int = 1,
act: str = 'linear',
alpha: Optional[Union[float, int]] = None,
gain: Optional[float] = None,
clamp: Optional[float] = None,
use_custom_op: bool = True):
r"""Fused bias and activation function.
Adds `bias` to activation tensor `input`, and evaluates activation
function `act`, and scales the result by `gain`. Each of the steps is
optional.
In most cases, the fused op is considerably more efficient than performing
the same calculation using standard PyTorch ops. It supports first and
second order gradients, but not third order gradients.
Args:
input (torch.Tensor): Input activation tensor. Can be of any shape.
bias (torch.Tensor): Bias vector, or `None` to disable.
Must be a 1D tensor of the same type as `input`. The shape must
be known, and it must match the dimension of `input` corresponding
to `dim`. Defaults to None.
dim (int): The dimension in `input` corresponding to the elements of
`bias`. The value of `dim` is ignored if `b` is not specified.
Defaults to 1.
act (str): Name of the activation function to evaluate, or `"linear"`
to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid",
"swish", etc. See `activation_funcs` for a full list. `None` is not
allowed. Defaults to `linear`.
alpha (float or int): Shape parameter for the activation
function, or `None` to use the default. Defaults to None.
gain (float): Scaling factor for the output tensor, or `None`
to use default. See `activation_funcs` for the default scaling of
each activation function. If unsure, consider specifying 1.
Defaults to None.
clamp (float): Clamp the output values to `[-clamp, +clamp]`,
or `None` to disable the clamping (default). Defaults to None.
use_custom_op (bool): Whether to use customized op.
Defaults to True.
Returns:
torch.Tensor: Tensor of the same shape and datatype as `input`.
"""
assert isinstance(input, torch.Tensor)
if use_custom_op and input.is_cuda:
return _bias_act_cuda(
dim=dim, act=act, alpha=alpha, gain=gain,
clamp=clamp).apply(input, bias)
return _bias_act_ref(
input=input,
bias=bias,
dim=dim,
act=act,
alpha=alpha,
gain=gain,
clamp=clamp)
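# A minimal usage sketch (illustrative): fuse a channel-wise bias with a
# leaky ReLU and an extra gain on an NCHW activation.
#
#     # x = torch.randn(4, 64, 32, 32, device='cuda')
#     # b = torch.zeros(64, device='cuda')
#     # y = bias_act(x, b, dim=1, act='lrelu', gain=2 ** 0.5)
#     # `y` has the same shape and dtype as `x`.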
def _bias_act_ref(input: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dim: int = 1,
act: str = 'linear',
alpha: Optional[Union[float, int]] = None,
gain: Optional[float] = None,
clamp: Optional[float] = None):
"""Slow reference implementation of `bias_act()` using standard PyTorch
ops.
Adds `bias` to activation tensor `input`, and evaluates activation
function `act`, and scales the result by `gain`. Each of the steps is
optional.
In most cases, the fused op is considerably more efficient than performing
the same calculation using standard PyTorch ops. It supports first and
second order gradients, but not third order gradients.
Args:
input (torch.Tensor): Input activation tensor. Can be of any shape.
bias (torch.Tensor): Bias vector, or `None` to disable.
Must be a 1D tensor of the same type as `input`. The shape must
be known, and it must match the dimension of `input` corresponding
to `dim`. Defaults to None.
dim (int): The dimension in `input` corresponding to the elements of
`bias`. The value of `dim` is ignored if `b` is not specified.
Defaults to 1.
act (str): Name of the activation function to evaluate, or `"linear"`
to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid",
"swish", etc. See `activation_funcs` for a full list. `None` is not
allowed. Defaults to `linear`.
alpha (float or int): Shape parameter for the activation
function, or `None` to use the default. Defaults to None.
gain (float): Scaling factor for the output tensor, or `None`
to use default. See `activation_funcs` for the default scaling of
each activation function. If unsure, consider specifying 1.
Defaults to None.
clamp (float): Clamp the output values to
`[-clamp, +clamp]`, or `None` to disable the clamping (default).
Defaults to None.
Returns:
torch.Tensor: Tensor of the same shape and datatype as `input`.
"""
assert isinstance(input, torch.Tensor)
assert clamp is None or clamp >= 0
spec = activation_funcs[act]
alpha = float(alpha if alpha is not None else spec.def_alpha)
gain = float(gain if gain is not None else spec.def_gain)
clamp = float(clamp if clamp is not None else -1)
# Add bias.
if bias is not None:
assert isinstance(bias, torch.Tensor) and bias.ndim == 1
assert 0 <= dim < input.ndim
assert bias.shape[0] == input.shape[dim]
input = input + bias.reshape(
[-1 if i == dim else 1 for i in range(input.ndim)])
# Evaluate activation function.
alpha = float(alpha)
output = spec.func(input, alpha=alpha)
# Scale by gain.
gain = float(gain)
if gain != 1:
output = output * gain
# Clamp.
if clamp >= 0:
# pylint: disable=invalid-unary-operand-type
output = output.clamp(-clamp, clamp)
return output
_bias_act_cuda_cache: Dict = dict()
def _bias_act_cuda(dim: int = 1,
act: str = 'linear',
alpha: Optional[Union[float, int]] = None,
gain: Optional[float] = None,
clamp: Optional[float] = None):
""""Fast CUDA implementation of `bias_act()` using custom ops.
Args:
dim (int): The dimension in `x` corresponding to the elements of `b`.
The value of `dim` is ignored if `b` is not specified.
Defaults to 1.
act (str): Name of the activation function to evaluate, or `"linear"`
to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid",
"swish", etc. See `activation_funcs` for a full list. `None` is not
allowed. Defaults to `linear`.
alpha (float | int): Shape parameter for the activation
function, or `None` to use the default. Defaults to None.
gain (float): Scaling factor for the output tensor, or `None`
to use default. See `activation_funcs` for the default scaling of
each activation function. If unsure, consider specifying 1.
Defaults to None.
clamp (float): Clamp the output values to `[-clamp, +clamp]`,
or `None` to disable the clamping (default). Defaults to None.
Returns:
torch.Tensor: Tensor of the same shape and datatype as `x`.
"""
# Parse arguments.
assert clamp is None or clamp >= 0
spec = activation_funcs[act]
alpha = float(alpha if alpha is not None else spec.def_alpha)
gain = float(gain if gain is not None else spec.def_gain)
clamp = float(clamp if clamp is not None else -1)
# Lookup from cache.
key = (dim, act, alpha, gain, clamp)
if key in _bias_act_cuda_cache:
return _bias_act_cuda_cache[key]
# Forward op.
class BiasActCuda(torch.autograd.Function):
@staticmethod
def forward(ctx, x, b): # pylint: disable=arguments-differ
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(
1) == 1 else torch.contiguous_format
x = x.contiguous(memory_format=ctx.memory_format)
b = b.contiguous() if b is not None else _null_tensor.to(x.device)
y = x
if act != 'linear' or gain != 1 or clamp >= 0 or (
b is not _null_tensor.to(x.device)):
y = ext_module.bias_act(x, b, _null_tensor.to(x.device),
_null_tensor.to(x.device),
_null_tensor.to(x.device), 0, dim,
spec.cuda_idx, alpha, gain, clamp)
ctx.save_for_backward(
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor.to(
x.device), b if 'x' in spec.ref or spec.has_2nd_grad else
_null_tensor.to(x.device),
y if 'y' in spec.ref else _null_tensor.to(x.device))
return y
@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
dy = dy.contiguous(memory_format=ctx.memory_format)
x, b, y = ctx.saved_tensors
dx = None
db = None
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
dx = dy
if act != 'linear' or gain != 1 or clamp >= 0:
dx = BiasActCudaGrad.apply(dy, x, b, y)
if ctx.needs_input_grad[1]:
db = dx.sum([i for i in range(dx.ndim) if i != dim])
return dx, db
# Backward op.
class BiasActCudaGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
ctx.memory_format = torch.channels_last if dy.ndim > 2 and (
dy.stride(1) == 1) else torch.contiguous_format
dx = ext_module.bias_act(dy, b, x, y, _null_tensor.to(x.device), 1,
dim, spec.cuda_idx, alpha, gain, clamp)
ctx.save_for_backward(
dy if spec.has_2nd_grad else _null_tensor.to(x.device), x, b,
y)
return dx
@staticmethod
def backward(ctx, d_dx): # pylint: disable=arguments-differ
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
dy, x, b, y = ctx.saved_tensors
d_dy = None
d_x = None
d_b = None
d_y = None
if ctx.needs_input_grad[0]:
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
if spec.has_2nd_grad and (ctx.needs_input_grad[1]
or ctx.needs_input_grad[2]):
d_x = ext_module.bias_act(d_dx, b, x, y, dy, 2, dim,
spec.cuda_idx, alpha, gain, clamp)
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
return d_dy, d_x, d_b, d_y
# Add to cache.
_bias_act_cuda_cache[key] = BiasActCuda
return BiasActCuda