Unverified Commit 94cc99d5 authored by Song Lin's avatar Song Lin Committed by GitHub
Browse files

[Enhancement] Add type hints for mmcv/arraymisc and mmcv/video (#1950)



* Add type hints

* Add type hints

* Fix int float about scalar

* Add type hints for mmcv/tensorrt

* Update mmcv/tensorrt/tensorrt_utils.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/arraymisc/quantization.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Ignore type hint for dtype
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 882cab77
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import numpy as np import numpy as np
def quantize(arr, min_val, max_val, levels, dtype=np.int64): def quantize(arr: np.ndarray,
min_val: Union[int, float],
max_val: Union[int, float],
levels: int,
dtype=np.int64) -> tuple:
"""Quantize an array of (-inf, inf) to [0, levels-1]. """Quantize an array of (-inf, inf) to [0, levels-1].
Args: Args:
arr (ndarray): Input array. arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped. min_val (int or float): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped. max_val (int or float): Maximum value to be clipped.
levels (int): Quantization levels. levels (int): Quantization levels.
dtype (np.type): The type of the quantized array. dtype (np.type): The type of the quantized array.
...@@ -29,13 +35,17 @@ def quantize(arr, min_val, max_val, levels, dtype=np.int64): ...@@ -29,13 +35,17 @@ def quantize(arr, min_val, max_val, levels, dtype=np.int64):
return quantized_arr return quantized_arr
def dequantize(arr, min_val, max_val, levels, dtype=np.float64): def dequantize(arr: np.ndarray,
min_val: Union[int, float],
max_val: Union[int, float],
levels: int,
dtype=np.float64) -> tuple:
"""Dequantize an array. """Dequantize an array.
Args: Args:
arr (ndarray): Input array. arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped. min_val (int or float): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped. max_val (int or float): Maximum value to be clipped.
levels (int): Quantization levels. levels (int): Quantization levels.
dtype (np.type): The type of the dequantized array. dtype (np.type): The type of the dequantized array.
......
...@@ -5,7 +5,7 @@ import os ...@@ -5,7 +5,7 @@ import os
import warnings import warnings
def get_tensorrt_op_path(): def get_tensorrt_op_path() -> str:
"""Get TensorRT plugins library path.""" """Get TensorRT plugins library path."""
# Following strings of text style are from colorama package # Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m' bright_style, reset_style = '\x1b[1m', '\x1b[0m'
...@@ -31,7 +31,7 @@ def get_tensorrt_op_path(): ...@@ -31,7 +31,7 @@ def get_tensorrt_op_path():
plugin_is_loaded = False plugin_is_loaded = False
def is_tensorrt_plugin_loaded(): def is_tensorrt_plugin_loaded() -> bool:
"""Check if TensorRT plugins library is loaded or not. """Check if TensorRT plugins library is loaded or not.
Returns: Returns:
...@@ -54,7 +54,7 @@ def is_tensorrt_plugin_loaded(): ...@@ -54,7 +54,7 @@ def is_tensorrt_plugin_loaded():
return plugin_is_loaded return plugin_is_loaded
def load_tensorrt_plugin(): def load_tensorrt_plugin() -> None:
"""load TensorRT plugins library.""" """load TensorRT plugins library."""
# Following strings of text style are from colorama package # Following strings of text style are from colorama package
......
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
import onnx import onnx
def preprocess_onnx(onnx_model): def preprocess_onnx(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Modify onnx model to match with TensorRT plugins in mmcv. """Modify onnx model to match with TensorRT plugins in mmcv.
There are some conflict between onnx node definition and TensorRT limit. There are some conflict between onnx node definition and TensorRT limit.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import warnings
from typing import Union
import onnx import onnx
import tensorrt as trt import tensorrt as trt
...@@ -8,12 +9,12 @@ import torch ...@@ -8,12 +9,12 @@ import torch
from .preprocess import preprocess_onnx from .preprocess import preprocess_onnx
def onnx2trt(onnx_model, def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
opt_shape_dict, opt_shape_dict: dict,
log_level=trt.Logger.ERROR, log_level: trt.ILogger.Severity = trt.Logger.ERROR,
fp16_mode=False, fp16_mode: bool = False,
max_workspace_size=0, max_workspace_size: int = 0,
device_id=0): device_id: int = 0) -> trt.ICudaEngine:
"""Convert onnx model to tensorrt engine. """Convert onnx model to tensorrt engine.
Arguments: Arguments:
...@@ -100,7 +101,7 @@ def onnx2trt(onnx_model, ...@@ -100,7 +101,7 @@ def onnx2trt(onnx_model,
return engine return engine
def save_trt_engine(engine, path): def save_trt_engine(engine: trt.ICudaEngine, path: str) -> None:
"""Serialize TensorRT engine to disk. """Serialize TensorRT engine to disk.
Arguments: Arguments:
...@@ -124,7 +125,7 @@ def save_trt_engine(engine, path): ...@@ -124,7 +125,7 @@ def save_trt_engine(engine, path):
f.write(bytearray(engine.serialize())) f.write(bytearray(engine.serialize()))
def load_trt_engine(path): def load_trt_engine(path: str) -> trt.ICudaEngine:
"""Deserialize TensorRT engine from disk. """Deserialize TensorRT engine from disk.
Arguments: Arguments:
...@@ -153,7 +154,7 @@ def load_trt_engine(path): ...@@ -153,7 +154,7 @@ def load_trt_engine(path):
return engine return engine
def torch_dtype_from_trt(dtype): def torch_dtype_from_trt(dtype: trt.DataType) -> Union[torch.dtype, TypeError]:
"""Convert pytorch dtype to TensorRT dtype.""" """Convert pytorch dtype to TensorRT dtype."""
if dtype == trt.bool: if dtype == trt.bool:
return torch.bool return torch.bool
...@@ -169,7 +170,8 @@ def torch_dtype_from_trt(dtype): ...@@ -169,7 +170,8 @@ def torch_dtype_from_trt(dtype):
raise TypeError('%s is not supported by torch' % dtype) raise TypeError('%s is not supported by torch' % dtype)
def torch_device_from_trt(device): def torch_device_from_trt(
device: trt.TensorLocation) -> Union[torch.device, TypeError]:
"""Convert pytorch device to TensorRT device.""" """Convert pytorch device to TensorRT device."""
if device == trt.TensorLocation.DEVICE: if device == trt.TensorLocation.DEVICE:
return torch.device('cuda') return torch.device('cuda')
......
...@@ -272,14 +272,14 @@ class VideoReader: ...@@ -272,14 +272,14 @@ class VideoReader:
self._vcap.release() self._vcap.release()
def frames2video(frame_dir, def frames2video(frame_dir: str,
video_file, video_file: str,
fps=30, fps: float = 30,
fourcc='XVID', fourcc: str = 'XVID',
filename_tmpl='{:06d}.jpg', filename_tmpl: str = '{:06d}.jpg',
start=0, start: int = 0,
end=0, end: int = 0,
show_progress=True): show_progress: bool = True) -> None:
"""Read the frame images from a directory and join them as a video. """Read the frame images from a directory and join them as a video.
Args: Args:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import warnings
from typing import Tuple, Union
import cv2 import cv2
import numpy as np import numpy as np
...@@ -9,7 +10,11 @@ from mmcv.image import imread, imwrite ...@@ -9,7 +10,11 @@ from mmcv.image import imread, imwrite
from mmcv.utils import is_str from mmcv.utils import is_str
def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs): def flowread(flow_or_path: Union[np.ndarray, str],
quantize: bool = False,
concat_axis: int = 0,
*args,
**kwargs) -> np.ndarray:
"""Read an optical flow map. """Read an optical flow map.
Args: Args:
...@@ -58,7 +63,12 @@ def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs): ...@@ -58,7 +63,12 @@ def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
return flow.astype(np.float32) return flow.astype(np.float32)
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): def flowwrite(flow: np.ndarray,
filename: str,
quantize: bool = False,
concat_axis: int = 0,
*args,
**kwargs) -> None:
"""Write optical flow to file. """Write optical flow to file.
If the flow is not quantized, it will be saved as a .flo file losslessly, If the flow is not quantized, it will be saved as a .flo file losslessly,
...@@ -88,7 +98,9 @@ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): ...@@ -88,7 +98,9 @@ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
imwrite(dxdy, filename) imwrite(dxdy, filename)
def quantize_flow(flow, max_val=0.02, norm=True): def quantize_flow(flow: np.ndarray,
max_val: float = 0.02,
norm: bool = True) -> tuple:
"""Quantize flow to [0, 255]. """Quantize flow to [0, 255].
After this step, the size of flow will be much smaller, and can be After this step, the size of flow will be much smaller, and can be
...@@ -116,7 +128,10 @@ def quantize_flow(flow, max_val=0.02, norm=True): ...@@ -116,7 +128,10 @@ def quantize_flow(flow, max_val=0.02, norm=True):
return tuple(flow_comps) return tuple(flow_comps)
def dequantize_flow(dx, dy, max_val=0.02, denorm=True): def dequantize_flow(dx: np.ndarray,
dy: np.ndarray,
max_val: float = 0.02,
denorm: bool = True) -> np.ndarray:
"""Recover from quantized flow. """Recover from quantized flow.
Args: Args:
...@@ -140,12 +155,15 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True): ...@@ -140,12 +155,15 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
return flow return flow
def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'): def flow_warp(img: np.ndarray,
flow: np.ndarray,
filling_value: int = 0,
interpolate_mode: str = 'nearest') -> np.ndarray:
"""Use flow to warp img. """Use flow to warp img.
Args: Args:
img (ndarray, float or uint8): Image to be warped. img (ndarray): Image to be warped.
flow (ndarray, float): Optical Flow. flow (ndarray): Optical Flow.
filling_value (int): The missing pixels will be set with filling_value. filling_value (int): The missing pixels will be set with filling_value.
interpolate_mode (str): bilinear -> Bilinear Interpolation; interpolate_mode (str): bilinear -> Bilinear Interpolation;
nearest -> Nearest Neighbor. nearest -> Nearest Neighbor.
...@@ -201,7 +219,7 @@ def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'): ...@@ -201,7 +219,7 @@ def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
return output.astype(img.dtype) return output.astype(img.dtype)
def flow_from_bytes(content): def flow_from_bytes(content: bytes) -> np.ndarray:
"""Read dense optical flow from bytes. """Read dense optical flow from bytes.
.. note:: .. note::
...@@ -231,7 +249,7 @@ def flow_from_bytes(content): ...@@ -231,7 +249,7 @@ def flow_from_bytes(content):
return flow return flow
def sparse_flow_from_bytes(content): def sparse_flow_from_bytes(content: bytes) -> Tuple[np.ndarray, np.ndarray]:
"""Read the optical flow in KITTI datasets from bytes. """Read the optical flow in KITTI datasets from bytes.
This function is modified from RAFT load the `KITTI datasets This function is modified from RAFT load the `KITTI datasets
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment