Unverified Commit 63b7aa31 authored by Kai Chen, committed by GitHub

Fix docstring formats (#383)

* update doc formats

* update docstring
parent a47451b4
@@ -216,7 +216,7 @@ def nms_match(dets, iou_threshold):
     """Matched dets into different groups by NMS.
 
     NMS match is Similar to NMS but when a bbox is suppressed, nms match will
-    record the indice of supporessed bbox and form a group with the indice of
+    record the indice of suppressed bbox and form a group with the indice of
     kept bbox. In each group, indice is sorted as score order.
 
     Arguments:
@@ -224,9 +224,9 @@ def nms_match(dets, iou_threshold):
         iou_thr (float): IoU thresh for NMS.
 
     Returns:
-        List[Tensor | ndarray]: The outer list corresponds different matched
-            group, the inner Tensor corresponds the indices for a group in
-            score order.
+        List[torch.Tensor | np.ndarray]: The outer list corresponds different
+            matched group, the inner Tensor corresponds the indices for a group
+            in score order.
     """
     if dets.shape[0] == 0:
         matched = []
...
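Note: a minimal usage sketch of nms_match with dummy boxes (the mmcv.ops import path is assumed; values are illustrative only):

    import numpy as np
    from mmcv.ops import nms_match

    # Each row is [x1, y1, x2, y2, score]; float32 is expected.
    dets = np.array([
        [49.1, 32.4, 51.0, 35.9, 0.9],
        [49.3, 32.9, 51.0, 35.3, 0.6],
        [35.3, 11.5, 39.9, 14.5, 0.4],
    ], dtype=np.float32)

    groups = nms_match(dets, 0.7)
    # Each group holds the index of a kept box followed by the indices of the
    # boxes it suppressed, sorted by score.
    for group in groups:
        print(group)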
@@ -134,9 +134,9 @@ def rel_roi_point_to_rel_img_point(rois,
 
 def point_sample(input, points, align_corners=False, **kwargs):
-    """A wrapper around :function:`grid_sample` to support 3D point_coords
-    tensors Unlike :function:`torch.nn.functional.grid_sample` it assumes
-    point_coords to lie inside [0, 1] x [0, 1] square.
+    """A wrapper around :func:`grid_sample` to support 3D point_coords tensors
+    Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
+    lie inside ``[0, 1] x [0, 1]`` square.
 
     Args:
         input (Tensor): Feature map, shape (N, C, H, W).
...
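Note: the core idea of such a wrapper is to rescale normalized coordinates from [0, 1] to the [-1, 1] range that torch.nn.functional.grid_sample expects. A standalone sketch of that idea (not necessarily the exact implementation in this repository):

    import torch
    import torch.nn.functional as F

    def point_sample_sketch(input, points, align_corners=False):
        # input: (N, C, H, W) feature map; points: (N, P, 2) in [0, 1] x [0, 1].
        grid = 2.0 * points.unsqueeze(2) - 1.0   # map [0, 1] -> [-1, 1], shape (N, P, 1, 2)
        out = F.grid_sample(input, grid, align_corners=align_corners)  # (N, C, P, 1)
        return out.squeeze(3)                    # (N, C, P)

    feats = torch.randn(2, 8, 16, 16)
    coords = torch.rand(2, 100, 2)
    sampled = point_sample_sketch(feats, coords)  # shape (2, 8, 100)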
@@ -4,8 +4,7 @@ from torch.nn.parallel._functions import _get_stream
 
 def scatter(input, devices, streams=None):
-    """Scatters tensor across multiple GPUs.
-    """
+    """Scatters tensor across multiple GPUs."""
     if streams is None:
         streams = [None] * len(devices)
...
...@@ -43,7 +43,7 @@ def scatter(inputs, target_gpus, dim=0): ...@@ -43,7 +43,7 @@ def scatter(inputs, target_gpus, dim=0):
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0): def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
"""Scatter with support for kwargs dictionary""" """Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_gpus, dim) if inputs else [] inputs = scatter(inputs, target_gpus, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs): if len(inputs) < len(kwargs):
......
@@ -27,5 +27,5 @@ __all__ = [
     'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only',
     'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
     'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
-    'IterBasedRunner', 'set_random_seed'
+    'set_random_seed'
 ]
@@ -247,7 +247,7 @@ class BaseRunner(metaclass=ABCMeta):
         """Register a hook into the hook list.
 
         The hook will be inserted into a priority queue, with the specified
-        priority (See :cls:`Priority` for details of priorities).
+        priority (See :class:`Priority` for details of priorities).
         For hooks with the same priority, they will be triggered in the same
        order as they are registered.
...
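Note: the insertion rule this docstring describes (priority-ordered, registration order preserved among equal priorities) can be sketched with a small standalone function; names here are illustrative, not the runner's actual attributes, and a smaller number is assumed to mean higher priority:

    from types import SimpleNamespace

    def insert_by_priority(hooks, hook):
        # Walk from the end so that hooks registered later with the same
        # priority stay after earlier ones.
        for i in range(len(hooks) - 1, -1, -1):
            if hook.priority >= hooks[i].priority:
                hooks.insert(i + 1, hook)
                return
        hooks.insert(0, hook)

    hooks = []
    insert_by_priority(hooks, SimpleNamespace(name='checkpoint', priority=70))
    insert_by_priority(hooks, SimpleNamespace(name='lr_updater', priority=10))
    insert_by_priority(hooks, SimpleNamespace(name='logger', priority=70))
    print([h.name for h in hooks])  # ['lr_updater', 'checkpoint', 'logger']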
@@ -103,8 +103,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
 
 def load_url_dist(url, model_dir=None):
-    """ In distributed setting, this function only download checkpoint at
-    local rank 0 """
+    """In distributed setting, this function only download checkpoint at local
+    rank 0."""
     rank, world_size = get_dist_info()
     rank = int(os.environ.get('LOCAL_RANK', rank))
     if rank == 0:
...
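Note: the "rank 0 downloads first, the others wait" pattern can be sketched roughly as below, assuming a standard torch.distributed setup; this is a sketch of the pattern, not the exact function body:

    import torch.distributed as dist
    from torch.utils import model_zoo

    def load_url_rank0_first(url, model_dir=None):
        # Download on rank 0 first so that processes sharing a filesystem do
        # not race on the same cache file.
        rank = dist.get_rank() if dist.is_initialized() else 0
        if rank == 0:
            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
        if dist.is_initialized():
            dist.barrier()
            if rank > 0:
                checkpoint = model_zoo.load_url(url, model_dir=model_dir)
        return checkpoint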
@@ -161,7 +161,7 @@ class EpochBasedRunner(BaseRunner):
 
 class Runner(EpochBasedRunner):
-    """Deprecated name of EpochBasedRunner"""
+    """Deprecated name of EpochBasedRunner."""
 
     def __init__(self, *args, **kwargs):
         warnings.warn(
...
@@ -5,7 +5,7 @@ from .hook import HOOKS, Hook
 
 class LrUpdaterHook(Hook):
-    """LR Scheduler in MMCV
+    """LR Scheduler in MMCV.
 
     Args:
         by_epoch (bool): LR changes epoch by epoch
@@ -325,7 +325,7 @@ def get_position_from_periods(iteration, cumulative_periods):
 
 @HOOKS.register_module()
 class CyclicLrUpdaterHook(LrUpdaterHook):
-    """Cyclic LR Scheduler
+    """Cyclic LR Scheduler.
 
     Implement the cyclical learning rate policy (CLR) described in
     https://arxiv.org/pdf/1506.01186.pdf
@@ -341,7 +341,6 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
         step_ratio_up (float): The ratio of the increasing process of LR in
             the total cycle.
         by_epoch (bool): Whether to update LR by epoch.
-
     """
 
     def __init__(self,
...
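Note: the cyclical LR policy cited above (Smith, arXiv:1506.01186) follows a triangular schedule; a standalone sketch of the generic formula, not the hook's exact annealing code:

    def triangular_clr(iteration, base_lr, max_lr, step_size_up, step_size_down=None):
        # One cycle = an increasing phase of `step_size_up` iterations followed
        # by a decreasing phase of `step_size_down` iterations.
        step_size_down = step_size_down or step_size_up
        cycle_len = step_size_up + step_size_down
        pos = iteration % cycle_len
        if pos < step_size_up:          # increasing phase
            scale = pos / step_size_up
        else:                           # decreasing phase
            scale = 1.0 - (pos - step_size_up) / step_size_down
        return base_lr + (max_lr - base_lr) * scale

    lrs = [triangular_clr(i, 1e-4, 1e-3, step_size_up=4) for i in range(16)]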
@@ -128,7 +128,7 @@ class CosineAnealingMomentumUpdaterHook(MomentumUpdaterHook):
 
 @HOOKS.register_module()
 class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
-    """Cyclic momentum Scheduler
+    """Cyclic momentum Scheduler.
 
     Implemet the cyclical momentum scheduler policy described in
     https://arxiv.org/pdf/1708.07120.pdf
@@ -143,7 +143,6 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
         step_ratio_up (float): The ratio of the increasing process of momentum
             in the total cycle.
         by_epoch (bool): Whether to update momentum by epoch.
-
     """
 
     def __init__(self,
...
@@ -31,7 +31,7 @@ class LogBuffer:
         self.n_history[key].append(count)
 
     def average(self, n=0):
-        """Average latest n values or all values"""
+        """Average latest n values or all values."""
        assert n >= 0
        for key in self.val_history:
            values = np.array(self.val_history[key][-n:])
...
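Note: averaging the latest n logged values, weighted by how many samples each entry covers, can be sketched with a standalone helper (illustrative code, not the class itself):

    import numpy as np

    def weighted_average(values, counts, n=0):
        # n == 0 averages the full history; otherwise only the latest n entries.
        values = np.array(values[-n:] if n > 0 else values)
        counts = np.array(counts[-n:] if n > 0 else counts)
        return np.sum(values * counts) / np.sum(counts)

    # e.g. per-iteration losses with their batch sizes
    print(weighted_average([0.9, 0.7, 0.4], [16, 16, 8], n=2))  # 0.6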
@@ -44,8 +44,11 @@ class DefaultOptimizerConstructor:
         model (:obj:`nn.Module`): The model with parameters to be optimized.
         optimizer_cfg (dict): The config dict of the optimizer.
             Positional fields are
+
             - `type`: class name of the optimizer.
+
             Optional fields are
+
             - any arguments of the corresponding optimizer type, e.g.,
               lr, weight_decay, momentum, etc.
         paramwise_cfg (dict, optional): Parameter-wise options.
...
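Note: a usage sketch of an optimizer config of this shape, assuming the build_optimizer helper exported in the __all__ list shown earlier in this diff; the model is a dummy:

    import torch.nn as nn
    from mmcv.runner import build_optimizer

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
    optimizer_cfg = dict(
        type='SGD',            # class name of the optimizer
        lr=0.01,               # any arguments of that optimizer type
        momentum=0.9,
        weight_decay=0.0001)
    optimizer = build_optimizer(model, optimizer_cfg)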
@@ -78,7 +78,6 @@ class Config:
     >>> cfg
     "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
     "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
-
     """
 
     @staticmethod
@@ -180,8 +179,7 @@ class Config:
 
     @staticmethod
     def auto_argparser(description=None):
-        """Generate argparser from config file automatically (experimental)
-        """
+        """Generate argparser from config file automatically (experimental)"""
         partial_parser = ArgumentParser(description=description)
         partial_parser.add_argument('config', help='config file path')
         cfg_file = partial_parser.parse_known_args()[0].config
@@ -356,7 +354,7 @@ class Config:
         mmcv.dump(cfg_dict, file)
 
     def merge_from_dict(self, options):
-        """Merge list into cfg_dict
+        """Merge list into cfg_dict.
 
         Merge the dict parsed by MultipleKVAction into this cfg.
...
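Note: a usage sketch of merge_from_dict, assuming the dotted-key dicts produced by MultipleKVAction-style command-line options; values are illustrative:

    from mmcv import Config

    cfg = Config(dict(model=dict(backbone=dict(type='ResNet', depth=50)),
                      data=dict(samples_per_gpu=2)))
    # Keys parsed from command-line options; dots address nested fields.
    options = {'model.backbone.depth': 101, 'data.samples_per_gpu': 4}
    cfg.merge_from_dict(options)
    assert cfg.model.backbone.depth == 101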
@@ -8,7 +8,7 @@ from .timer import Timer
 
 class ProgressBar:
-    """A progress bar which can print the progress"""
+    """A progress bar which can print the progress."""
 
     def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
         self.task_num = task_num
@@ -176,7 +176,8 @@ def track_parallel_progress(func,
 
 def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
-    """Track the progress of tasks iteration or enumeration with a progress bar.
+    """Track the progress of tasks iteration or enumeration with a progress
+    bar.
 
     Tasks are yielded with a simple for-loop.
...
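Note: a typical use of the progress utility touched here; the task list and the processing function are placeholders:

    import mmcv

    def process(path):
        return len(path)  # placeholder for real per-item work

    tasks = ['a.jpg', 'b.jpg', 'c.jpg']
    # Wrap any sized iterable; the bar advances once per yielded item.
    results = [process(t) for t in mmcv.track_iter_progress(tasks)]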
@@ -200,7 +200,7 @@ class VideoReader:
                    start=0,
                    max_num=0,
                    show_progress=True):
-        """Convert a video to frame images
+        """Convert a video to frame images.
 
         Args:
             frame_dir (str): Output directory to store all the frame images.
@@ -282,7 +282,7 @@ def frames2video(frame_dir,
                  start=0,
                  end=0,
                  show_progress=True):
-    """Read the frame images from a directory and join them as a video
+    """Read the frame images from a directory and join them as a video.
 
     Args:
         frame_dir (str): The directory containing video frames.
...
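Note: how the two directions are commonly combined; file names and the output directory are placeholders, and cvt2frames is the VideoReader method shown above:

    import mmcv

    video = mmcv.VideoReader('demo.mp4')
    video.cvt2frames('frames_dir')                 # video -> frame images
    mmcv.frames2video('frames_dir', 'rebuilt.avi',
                      fps=video.fps)               # frame images -> video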
@@ -139,7 +139,7 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
 
 def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
-    """Use flow to warp img
+    """Use flow to warp img.
 
     Args:
         img (ndarray, float or uint8): Image to be warped.
...
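Note: a minimal call sketch for flow_warp, assuming the top-level mmcv export; the data is random for illustration, and in practice the flow usually comes from mmcv.flowread or an estimator:

    import numpy as np
    import mmcv

    img = np.random.rand(256, 256, 3).astype(np.float32)
    # Per-pixel displacement field with shape (H, W, 2): (dx, dy).
    flow = np.zeros((256, 256, 2), dtype=np.float32)
    flow[..., 0] = 5.0                      # shift everything 5 px horizontally
    warped = mmcv.flow_warp(img, flow, interpolate_mode='bilinear')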
@@ -14,9 +14,8 @@ from Cython.Distutils import build_ext as build_cmd  # NOQA: E402  # isort:skip
 
 def choose_requirement(primary, secondary):
-    """If some version of primary requirement installed, return primary,
-    else return secondary.
-    """
+    """If some version of primary requirement installed, return primary, else
+    return secondary."""
     try:
         name = re.split(r'[!<>=]', primary)[0]
         get_distribution(name)
@@ -40,8 +39,7 @@ def get_version():
 
 def parse_requirements(fname='requirements.txt', with_version=True):
-    """
-    Parse the package dependencies listed in a requirements file but strips
+    """Parse the package dependencies listed in a requirements file but strips
     specific versioning information.
 
     Args:
@@ -60,9 +58,7 @@ def parse_requirements(fname='requirements.txt', with_version=True):
     require_fpath = fname
 
     def parse_line(line):
-        """
-        Parse information from a line in a requirements text file
-        """
+        """Parse information from a line in a requirements text file."""
        if line.startswith('-r '):
            # Allow specifying requirements in other files
            target = line.split(' ')[1]
...
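Note: the fallback logic of choose_requirement boils down to a try/except around pkg_resources.get_distribution; a self-contained sketch, with example package names:

    import re
    from pkg_resources import DistributionNotFound, get_distribution

    def choose_requirement(primary, secondary):
        # Keep `primary` only if some version of it is already installed,
        # otherwise fall back to `secondary`.
        try:
            name = re.split(r'[!<>=]', primary)[0]
            get_distribution(name)
        except DistributionNotFound:
            return secondary
        return primary

    print(choose_requirement('opencv-python-headless>=3', 'opencv-python>=3'))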
""" """Tests the hooks with runners.
Tests the hooks with runners.
CommandLine: CommandLine:
pytest tests/test_hooks.py pytest tests/test_hooks.py
xdoctest tests/test_hooks.py zero xdoctest tests/test_hooks.py zero
""" """
import logging import logging
import os.path as osp import os.path as osp
...@@ -49,9 +47,7 @@ def test_pavi_hook(): ...@@ -49,9 +47,7 @@ def test_pavi_hook():
def test_momentum_runner_hook(): def test_momentum_runner_hook():
""" """xdoctest -m tests/test_hooks.py test_momentum_runner_hook."""
xdoctest -m tests/test_hooks.py test_momentum_runner_hook
"""
sys.modules['pavi'] = MagicMock() sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2))) loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner() runner = _build_demo_runner()
...@@ -99,9 +95,7 @@ def test_momentum_runner_hook(): ...@@ -99,9 +95,7 @@ def test_momentum_runner_hook():
def test_cosine_runner_hook(): def test_cosine_runner_hook():
""" """xdoctest -m tests/test_hooks.py test_cosine_runner_hook."""
xdoctest -m tests/test_hooks.py test_cosine_runner_hook
"""
sys.modules['pavi'] = MagicMock() sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2))) loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner() runner = _build_demo_runner()
......