Unverified Commit 63b7aa31 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Fix docstring formats (#383)

* update doc formats

* update docstring
parent a47451b4
......@@ -216,7 +216,7 @@ def nms_match(dets, iou_threshold):
"""Matched dets into different groups by NMS.
NMS match is Similar to NMS but when a bbox is suppressed, nms match will
record the indice of supporessed bbox and form a group with the indice of
record the indice of suppressed bbox and form a group with the indice of
kept bbox. In each group, indice is sorted as score order.
Arguments:
......@@ -224,9 +224,9 @@ def nms_match(dets, iou_threshold):
iou_thr (float): IoU thresh for NMS.
Returns:
List[Tensor | ndarray]: The outer list corresponds different matched
group, the inner Tensor corresponds the indices for a group in
score order.
List[torch.Tensor | np.ndarray]: The outer list corresponds different
matched group, the inner Tensor corresponds the indices for a group
in score order.
"""
if dets.shape[0] == 0:
matched = []
......
......@@ -134,9 +134,9 @@ def rel_roi_point_to_rel_img_point(rois,
def point_sample(input, points, align_corners=False, **kwargs):
"""A wrapper around :function:`grid_sample` to support 3D point_coords
tensors Unlike :function:`torch.nn.functional.grid_sample` it assumes
point_coords to lie inside [0, 1] x [0, 1] square.
"""A wrapper around :func:`grid_sample` to support 3D point_coords tensors
Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
lie inside ``[0, 1] x [0, 1]`` square.
Args:
input (Tensor): Feature map, shape (N, C, H, W).
......
......@@ -4,8 +4,7 @@ from torch.nn.parallel._functions import _get_stream
def scatter(input, devices, streams=None):
"""Scatters tensor across multiple GPUs.
"""
"""Scatters tensor across multiple GPUs."""
if streams is None:
streams = [None] * len(devices)
......
......@@ -43,7 +43,7 @@ def scatter(inputs, target_gpus, dim=0):
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
"""Scatter with support for kwargs dictionary"""
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_gpus, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs):
......
......@@ -27,5 +27,5 @@ __all__ = [
'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only',
'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
'IterBasedRunner', 'set_random_seed'
'set_random_seed'
]
......@@ -247,7 +247,7 @@ class BaseRunner(metaclass=ABCMeta):
"""Register a hook into the hook list.
The hook will be inserted into a priority queue, with the specified
priority (See :cls:`Priority` for details of priorities).
priority (See :class:`Priority` for details of priorities).
For hooks with the same priority, they will be triggered in the same
order as they are registered.
......
......@@ -103,8 +103,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
def load_url_dist(url, model_dir=None):
""" In distributed setting, this function only download checkpoint at
local rank 0 """
"""In distributed setting, this function only download checkpoint at local
rank 0."""
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:
......
......@@ -161,7 +161,7 @@ class EpochBasedRunner(BaseRunner):
class Runner(EpochBasedRunner):
"""Deprecated name of EpochBasedRunner"""
"""Deprecated name of EpochBasedRunner."""
def __init__(self, *args, **kwargs):
warnings.warn(
......
......@@ -5,7 +5,7 @@ from .hook import HOOKS, Hook
class LrUpdaterHook(Hook):
"""LR Scheduler in MMCV
"""LR Scheduler in MMCV.
Args:
by_epoch (bool): LR changes epoch by epoch
......@@ -325,7 +325,7 @@ def get_position_from_periods(iteration, cumulative_periods):
@HOOKS.register_module()
class CyclicLrUpdaterHook(LrUpdaterHook):
"""Cyclic LR Scheduler
"""Cyclic LR Scheduler.
Implement the cyclical learning rate policy (CLR) described in
https://arxiv.org/pdf/1506.01186.pdf
......@@ -341,7 +341,6 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
step_ratio_up (float): The ratio of the increasing process of LR in
the total cycle.
by_epoch (bool): Whether to update LR by epoch.
"""
def __init__(self,
......
......@@ -128,7 +128,7 @@ class CosineAnealingMomentumUpdaterHook(MomentumUpdaterHook):
@HOOKS.register_module()
class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
"""Cyclic momentum Scheduler
"""Cyclic momentum Scheduler.
Implemet the cyclical momentum scheduler policy described in
https://arxiv.org/pdf/1708.07120.pdf
......@@ -143,7 +143,6 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
step_ratio_up (float): The ratio of the increasing process of momentum
in the total cycle.
by_epoch (bool): Whether to update momentum by epoch.
"""
def __init__(self,
......
......@@ -31,7 +31,7 @@ class LogBuffer:
self.n_history[key].append(count)
def average(self, n=0):
"""Average latest n values or all values"""
"""Average latest n values or all values."""
assert n >= 0
for key in self.val_history:
values = np.array(self.val_history[key][-n:])
......
......@@ -44,8 +44,11 @@ class DefaultOptimizerConstructor:
model (:obj:`nn.Module`): The model with parameters to be optimized.
optimizer_cfg (dict): The config dict of the optimizer.
Positional fields are
- `type`: class name of the optimizer.
Optional fields are
- any arguments of the corresponding optimizer type, e.g.,
lr, weight_decay, momentum, etc.
paramwise_cfg (dict, optional): Parameter-wise options.
......
......@@ -78,7 +78,6 @@ class Config:
>>> cfg
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
"""
@staticmethod
......@@ -180,8 +179,7 @@ class Config:
@staticmethod
def auto_argparser(description=None):
"""Generate argparser from config file automatically (experimental)
"""
"""Generate argparser from config file automatically (experimental)"""
partial_parser = ArgumentParser(description=description)
partial_parser.add_argument('config', help='config file path')
cfg_file = partial_parser.parse_known_args()[0].config
......@@ -356,7 +354,7 @@ class Config:
mmcv.dump(cfg_dict, file)
def merge_from_dict(self, options):
"""Merge list into cfg_dict
"""Merge list into cfg_dict.
Merge the dict parsed by MultipleKVAction into this cfg.
......
......@@ -8,7 +8,7 @@ from .timer import Timer
class ProgressBar:
"""A progress bar which can print the progress"""
"""A progress bar which can print the progress."""
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
self.task_num = task_num
......@@ -176,7 +176,8 @@ def track_parallel_progress(func,
def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
"""Track the progress of tasks iteration or enumeration with a progress bar.
"""Track the progress of tasks iteration or enumeration with a progress
bar.
Tasks are yielded with a simple for-loop.
......
......@@ -200,7 +200,7 @@ class VideoReader:
start=0,
max_num=0,
show_progress=True):
"""Convert a video to frame images
"""Convert a video to frame images.
Args:
frame_dir (str): Output directory to store all the frame images.
......@@ -282,7 +282,7 @@ def frames2video(frame_dir,
start=0,
end=0,
show_progress=True):
"""Read the frame images from a directory and join them as a video
"""Read the frame images from a directory and join them as a video.
Args:
frame_dir (str): The directory containing video frames.
......
......@@ -139,7 +139,7 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
"""Use flow to warp img
"""Use flow to warp img.
Args:
img (ndarray, float or uint8): Image to be warped.
......
......@@ -14,9 +14,8 @@ from Cython.Distutils import build_ext as build_cmd # NOQA: E402 # isort:skip
def choose_requirement(primary, secondary):
"""If some version of primary requirement installed, return primary,
else return secondary.
"""
"""If some version of primary requirement installed, return primary, else
return secondary."""
try:
name = re.split(r'[!<>=]', primary)[0]
get_distribution(name)
......@@ -40,8 +39,7 @@ def get_version():
def parse_requirements(fname='requirements.txt', with_version=True):
"""
Parse the package dependencies listed in a requirements file but strips
"""Parse the package dependencies listed in a requirements file but strips
specific versioning information.
Args:
......@@ -60,9 +58,7 @@ def parse_requirements(fname='requirements.txt', with_version=True):
require_fpath = fname
def parse_line(line):
"""
Parse information from a line in a requirements text file
"""
"""Parse information from a line in a requirements text file."""
if line.startswith('-r '):
# Allow specifying requirements in other files
target = line.split(' ')[1]
......
"""
Tests the hooks with runners.
"""Tests the hooks with runners.
CommandLine:
pytest tests/test_hooks.py
xdoctest tests/test_hooks.py zero
"""
import logging
import os.path as osp
......@@ -49,9 +47,7 @@ def test_pavi_hook():
def test_momentum_runner_hook():
"""
xdoctest -m tests/test_hooks.py test_momentum_runner_hook
"""
"""xdoctest -m tests/test_hooks.py test_momentum_runner_hook."""
sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner()
......@@ -99,9 +95,7 @@ def test_momentum_runner_hook():
def test_cosine_runner_hook():
"""
xdoctest -m tests/test_hooks.py test_cosine_runner_hook
"""
"""xdoctest -m tests/test_hooks.py test_cosine_runner_hook."""
sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment