Unverified Commit 5947178e authored by Zaida Zhou, committed by GitHub

Remove many functions in utils and migrate them to mmengine (#2217)

* Remove runner, parallel, engine and device

* fix format

* remove outdated docs

* migrate many functions to mmengine

* remove sync_bn.py
parent 9185eee8
import argparse
import json
import os
import os.path as osp
import time
import warnings
from collections import OrderedDict
from unittest.mock import patch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader
import mmcv
from mmcv.runner import build_runner
from mmcv.utils import get_logger
def parse_args():
    parser = argparse.ArgumentParser(
        description='Visualize the given config '
        'of learning rate and momentum, and this '
        'script will overwrite the log_config')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--work-dir', default='./', help='the dir to save logs and models')
    parser.add_argument(
        '--num-iters', default=300, type=int,
        help='The number of iters per epoch')
    parser.add_argument(
        '--num-epochs', default=300, type=int,
        help='Only used in EpochBasedRunner')
    parser.add_argument(
        '--window-size',
        default='12*14',
        help='Size of the window to display images, in format of "$W*$H".')
    parser.add_argument(
        '--log-interval', default=10, type=int,
        help='The interval of TextLoggerHook')
    args = parser.parse_args()
    return args

class SimpleModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def train_step(self, *args, **kwargs):
        return dict()

    def val_step(self, *args, **kwargs):
        return dict()

# Stripped-down train loops that are patched onto the runners below. They only
# fire the hooks, so the LR/momentum schedules advance without real training.
def iter_train(self, data_loader, **kwargs):
    self.mode = 'train'
    self.data_loader = data_loader
    self.call_hook('before_train_iter')
    self.call_hook('after_train_iter')
    self._inner_iter += 1
    self._iter += 1


def epoch_train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.call_hook('after_train_iter')
        self._iter += 1
    self.call_hook('after_train_epoch')
    self._epoch += 1

def log(self, runner):
    cur_iter = self.get_iter(runner, inner_iter=True)

    log_dict = OrderedDict(
        mode=self.get_mode(runner),
        epoch=self.get_epoch(runner),
        iter=cur_iter)

    # only record lr of the first param group
    cur_lr = runner.current_lr()
    if isinstance(cur_lr, list):
        log_dict['lr'] = cur_lr[0]
    else:
        assert isinstance(cur_lr, dict)
        log_dict['lr'] = {}
        for k, lr_ in cur_lr.items():
            assert isinstance(lr_, list)
            log_dict['lr'].update({k: lr_[0]})

    cur_momentum = runner.current_momentum()
    if isinstance(cur_momentum, list):
        log_dict['momentum'] = cur_momentum[0]
    else:
        assert isinstance(cur_momentum, dict)
        log_dict['momentum'] = {}
        for k, lr_ in cur_momentum.items():
            assert isinstance(lr_, list)
            log_dict['momentum'].update({k: lr_[0]})

    log_dict = dict(log_dict, **runner.log_buffer.output)
    self._log_info(log_dict, runner)
    self._dump_log(log_dict, runner)
    return log_dict

# Patch out CUDA and the real training loops / logger so that only the hooks
# run; this is enough to record the scheduled lr and momentum values.
@patch('torch.cuda.is_available', lambda: False)
@patch('mmcv.runner.EpochBasedRunner.train', epoch_train)
@patch('mmcv.runner.IterBasedRunner.train', iter_train)
@patch('mmcv.runner.hooks.TextLoggerHook.log', log)
def run(cfg, logger):
    momentum_config = cfg.get('momentum_config')
    lr_config = cfg.get('lr_config')

    model = SimpleModel()
    optimizer = SGD(model.parameters(), 0.1, momentum=0.8)
    cfg.work_dir = cfg.get('work_dir', './')
    workflow = [('train', 1)]

    if cfg.get('runner') is None:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.get('total_epochs', cfg.num_epochs)
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    batch_size = 1
    data = cfg.get('data')
    if data:
        batch_size = data.get('samples_per_gpu')
    fake_dataloader = DataLoader(
        list(range(cfg.num_iters)), batch_size=batch_size)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=None))
    log_config = dict(
        interval=cfg.log_interval, hooks=[
            dict(type='TextLoggerHook'),
        ])
    runner.register_training_hooks(lr_config, log_config=log_config)
    runner.register_momentum_hook(momentum_config)
    runner.run([fake_dataloader], workflow)

def plot_lr_curve(json_file, cfg):
    data_dict = dict(LearningRate=[], Momentum=[])
    assert os.path.isfile(json_file)
    with open(json_file) as f:
        for line in f:
            log = json.loads(line.strip())
            data_dict['LearningRate'].append(log['lr'])
            data_dict['Momentum'].append(log['momentum'])

    wind_w, wind_h = (int(size) for size in cfg.window_size.split('*'))
    # if legend is None, use {filename}_{key} as legend
    fig, axes = plt.subplots(2, 1, figsize=(wind_w, wind_h))
    plt.subplots_adjust(hspace=0.5)
    font_size = 20
    for index, (updater_type, data_list) in enumerate(data_dict.items()):
        ax = axes[index]
        if cfg.runner.type == 'EpochBasedRunner':
            ax.plot(data_list, linewidth=1)
            ax.xaxis.tick_top()
            ax.set_xlabel('Iters', fontsize=font_size)
            ax.xaxis.set_label_position('top')
            sec_ax = ax.secondary_xaxis(
                'bottom',
                functions=(lambda x: x / cfg.num_iters * cfg.log_interval,
                           lambda y: y * cfg.num_iters / cfg.log_interval))
            sec_ax.tick_params(labelsize=font_size)
            sec_ax.set_xlabel('Epochs', fontsize=font_size)
        else:
            # plt.subplot(2, 1, index + 1)
            x_list = np.arange(len(data_list)) * cfg.log_interval
            ax.plot(x_list, data_list)
            ax.set_xlabel('Iters', fontsize=font_size)
        ax.set_ylabel(updater_type, fontsize=font_size)
        if updater_type == 'LearningRate':
            if cfg.get('lr_config'):
                title = cfg.lr_config.type
            else:
                title = 'No learning rate scheduler'
        else:
            if cfg.get('momentum_config'):
                title = cfg.momentum_config.type
            else:
                title = 'No momentum scheduler'
        ax.set_title(title, fontsize=font_size)
        ax.grid()
        # set tick font size
        ax.tick_params(labelsize=font_size)

    save_path = osp.join(cfg.work_dir, 'visualization-result')
    plt.savefig(save_path)
    print(f'The learning rate graph is saved at {save_path}.png')
    plt.show()

def main():
    args = parse_args()
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    cfg = mmcv.Config.fromfile(args.config)
    cfg['num_iters'] = args.num_iters
    cfg['num_epochs'] = args.num_epochs
    cfg['log_interval'] = args.log_interval
    cfg['window_size'] = args.window_size

    log_path = osp.join(cfg.get('work_dir', './'), f'{timestamp}.log')
    json_path = log_path + '.json'
    logger = get_logger('mmcv', log_path)

    run(cfg, logger)
    plot_lr_curve(json_path, cfg)


if __name__ == '__main__':
    main()
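A minimal sketch of how the script is meant to be driven (the config path below is illustrative; for the curves to be meaningful the config should define an `lr_config`, and optionally a `momentum_config`):

```python
import sys

# Simulated command line, equivalent to:
#   python visualize_lr.py path/to/config.py --num-iters 200 --log-interval 10
# The plot is written to <work_dir>/visualization-result.png, where work_dir
# comes from the config (or defaults to './').
sys.argv = [
    'visualize_lr.py', 'path/to/config.py',
    '--num-iters', '200', '--log-interval', '10', '--window-size', '12*14'
]
main()
```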
@@ -15,13 +15,10 @@ You can switch between Chinese and English documents in the lower-left corner of
   :maxdepth: 2
   :caption: Understand MMCV

-  understand_mmcv/config.md
-  understand_mmcv/registry.md
   understand_mmcv/data_process.md
   understand_mmcv/visualization.md
   understand_mmcv/cnn.md
   understand_mmcv/ops.md
-  understand_mmcv/utils.md

.. toctree::
   :maxdepth: 2
...
## Config
The `Config` class is used for manipulating configs and config files. It supports
loading configs from multiple file formats, including **python**, **json** and **yaml**.
It provides dict-like APIs to get and set values.
Here is an example of the config file `test.py`.
```python
a = 1
b = dict(b1=[0, 1, 2], b2=None)
c = (1, 2)
d = 'string'
```
To load and use configs
```python
>>> cfg = Config.fromfile('test.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=None),
... c=(1, 2),
... d='string')
```
Predefined variables are supported in all config formats; every occurrence of `{{ var }}` is replaced with its actual value.
Currently, four predefined variables are supported:
`{{ fileDirname }}` - the current opened file's dirname, e.g. /home/your-username/your-project/folder
`{{ fileBasename }}` - the current opened file's basename, e.g. file.ext
`{{ fileBasenameNoExtension }}` - the current opened file's basename with no file extension, e.g. file
`{{ fileExtname }}` - the current opened file's extension, e.g. .ext
These variable names are borrowed from [VS Code](https://code.visualstudio.com/docs/editor/variables-reference).
Here is an example of a config that uses predefined variables.
`config_a.py`
```python
a = 1
b = './work_dir/{{ fileBasenameNoExtension }}'
c = '{{ fileExtname }}'
```
```python
>>> cfg = Config.fromfile('./config_a.py')
>>> print(cfg)
>>> dict(a=1,
... b='./work_dir/config_a',
... c='.py')
```
Inheritance is supported in all config formats. To reuse fields in other config files,
specify `_base_='./config_a.py'` or a list of configs `_base_=['./config_a.py', './config_b.py']`.
Here are 4 examples of config inheritance.
`config_a.py`
```python
a = 1
b = dict(b1=[0, 1, 2], b2=None)
```
### Inherit from base config without overlapped keys
`config_b.py`
```python
_base_ = './config_a.py'
c = (1, 2)
d = 'string'
```
```python
>>> cfg = Config.fromfile('./config_b.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=None),
... c=(1, 2),
... d='string')
```
New fields in `config_b.py` are combined with the old fields in `config_a.py`.
### Inherit from base config with overlapped keys
`config_c.py`
```python
_base_ = './config_a.py'
b = dict(b2=1)
c = (1, 2)
```
```python
>>> cfg = Config.fromfile('./config_c.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=1),
... c=(1, 2))
```
`b.b2=None` in `config_a` is replaced with `b.b2=1` in `config_c.py`.
### Inherit from base config with ignored fields
`config_d.py`
```python
_base_ = './config_a.py'
b = dict(_delete_=True, b2=None, b3=0.1)
c = (1, 2)
```
```python
>>> cfg = Config.fromfile('./config_d.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b2=None, b3=0.1),
... c=(1, 2))
```
You may also set `_delete_=True` to ignore some fields in base configs. The old keys `b1, b2` in `b` are replaced with the new keys `b2, b3`.
### Inherit from multiple base configs (the base configs should not contain the same keys)
`config_e.py`
```python
c = (1, 2)
d = 'string'
```
`config_f.py`
```python
_base_ = ['./config_a.py', './config_e.py']
```
```python
>>> cfg = Config.fromfile('./config_f.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=None),
... c=(1, 2),
... d='string')
```
### Reference variables from base
You can reference variables defined in base using the following grammar.
`base.py`
```python
item1 = 'a'
item2 = dict(item3 = 'b')
```
`config_g.py`
```python
_base_ = ['./base.py']
item = dict(a = {{ _base_.item1 }}, b = {{ _base_.item2.item3 }})
```
```python
>>> cfg = Config.fromfile('./config_g.py')
>>> print(cfg.pretty_text)
item1 = 'a'
item2 = dict(item3='b')
item = dict(a='a', b='b')
```
### Add deprecation information in configs
Deprecation information can be added in a config file, which will trigger a `UserWarning` when this config file is loaded.
`deprecated_cfg.py`
```python
_base_ = 'expected_cfg.py'
_deprecation_ = dict(
    expected='expected_cfg.py',  # optional to show expected config path in the warning information
    reference='url to related PR'  # optional to show reference link in the warning information
)
```
```python
>>> cfg = Config.fromfile('./deprecated_cfg.py')
UserWarning: The config file deprecated_cfg.py will be deprecated in the future. Please use expected_cfg.py instead. More information can be found at https://github.com/open-mmlab/mmcv/pull/1275
```
## Registry
MMCV implements [registry](https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py) to manage different modules that share similar functionalities, e.g., backbones, heads, and necks in detectors.
Most projects in OpenMMLab use registry to manage modules of datasets and models, such as [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d), [MMClassification](https://github.com/open-mmlab/mmclassification), [MMEditing](https://github.com/open-mmlab/mmediting), etc.
```{note}
In v1.5.1 and later, the Registry supports registering functions and calling them.
```
### What is registry
In MMCV, a registry can be regarded as a mapping from strings to classes or functions.
The classes or functions contained in a single registry usually have similar APIs but implement different algorithms or support different datasets.
With the registry, users can find a class or function through its corresponding string, and instantiate the corresponding module or call the function as needed.
One typical example is the config system in most OpenMMLab projects, which uses the registry to create hooks, runners, models, and datasets from configs.
The API reference can be found [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.Registry).
To manage your modules in the codebase with `Registry`, there are three steps as below.
1. Create a build method (optional, in most cases you can just use the default one).
2. Create a registry.
3. Use this registry to manage the modules.
The `build_func` argument of `Registry` customizes how the class instance is built or how the function is called to obtain the result; the default is `build_from_cfg`, implemented [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.build_from_cfg).
### A Simple Example
Here we show a simple example of using registry to manage modules in a package.
You can find more practical examples in OpenMMLab projects.
Assume we want to implement a series of dataset converters that convert data in different formats into the expected data format.
We create a directory as a package named `converters`.
In the package, we first create a file to implement builders, named `converters/builder.py`, as below
```python
from mmcv.utils import Registry
# create a registry for converters
CONVERTERS = Registry('converters')
```
Then we can implement different converters in the package; each can be a class or a function. For example, implement `Converter1` in `converters/converter1.py`, and `converter2` in `converters/converter2.py`.
```python
from .builder import CONVERTERS
# use the registry to manage the module
@CONVERTERS.register_module()
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b
```
```python
# converter2.py
from .builder import CONVERTERS
from .converter1 import Converter1
# use the registry to manage the module
@CONVERTERS.register_module()
def converter2(a, b):
    return Converter1(a, b)
```
The key step to use the registry for managing the modules is to register the implemented module into the registry `CONVERTERS` through
`@CONVERTERS.register_module()` when you are creating the module. In this way, a mapping between a string and the class (or function) is built and maintained by `CONVERTERS`, as below
```python
'Converter1' -> <class 'Converter1'>
'converter2' -> <function 'converter2'>
```
```{note}
The registry mechanism will be triggered only when the file where the module is located is imported.
So you need to import that file somewhere; a minimal `__init__.py` sketch follows this note. More details can be found at https://github.com/open-mmlab/mmdetection/issues/5974.
```
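For instance, a minimal `converters/__init__.py` (a hypothetical layout for the package above) makes sure both modules are imported, so their `register_module()` decorators run and the mapping is populated:

```python
# converters/__init__.py (hypothetical; any import location works)
from .builder import CONVERTERS
from .converter1 import Converter1
from .converter2 import converter2

__all__ = ['CONVERTERS', 'Converter1', 'converter2']
```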
If the module is successfully registered, you can use this converter through configs as
```python
converter1_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter2_cfg = dict(type='converter2', a=a_value, b=b_value)
converter1 = CONVERTERS.build(converter1_cfg)
# returns the calling result
result = CONVERTERS.build(converter2_cfg)
```
### Customize Build Function
Suppose we would like to customize how `converters` are built; we can implement a customized `build_func` and pass it into the registry.
```python
from mmcv.utils import Registry
# create a build function
def build_converter(cfg, registry, *args, **kwargs):
    cfg_ = cfg.copy()
    converter_type = cfg_.pop('type')
    if converter_type not in registry:
        raise KeyError(f'Unrecognized converter type {converter_type}')
    else:
        converter_cls = registry.get(converter_type)

    converter = converter_cls(*args, **kwargs, **cfg_)
    return converter
# create a registry for converters and pass ``build_converter`` function
CONVERTERS = Registry('converter', build_func=build_converter)
```
```{note}
In this example, we demonstrate how to use the `build_func` argument to customize the way a class instance is built.
The functionality is similar to the default `build_from_cfg`; in most cases the default one is sufficient.
`build_model_from_cfg` is also provided to build PyTorch modules into an `nn.Sequential`; you can use it directly instead of implementing your own (a sketch follows this note).
```
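As a sketch of that last point (assuming `build_model_from_cfg` can be imported from `mmcv.cnn`; the exact import path may differ between versions), passing a list of configs to such a registry stacks the built modules into an `nn.Sequential`:

```python
import torch.nn as nn

from mmcv.cnn import build_model_from_cfg  # import path may vary by version
from mmcv.utils import Registry

BLOCKS = Registry('block', build_func=build_model_from_cfg)


@BLOCKS.register_module()
class ConvBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


# a single config builds one module ...
single = BLOCKS.build(dict(type='ConvBlock', in_channels=3, out_channels=8))
# ... while a list of configs builds the modules into an nn.Sequential
stacked = BLOCKS.build([
    dict(type='ConvBlock', in_channels=3, out_channels=8),
    dict(type='ConvBlock', in_channels=8, out_channels=16),
])
assert isinstance(stacked, nn.Sequential)
```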
### Hierarchy Registry
You can also build modules from more than one OpenMMLab framework, e.g. you can use all backbones in [MMClassification](https://github.com/open-mmlab/mmclassification) for object detectors in [MMDetection](https://github.com/open-mmlab/mmdetection), or combine an object detection model in [MMDetection](https://github.com/open-mmlab/mmdetection) with a semantic segmentation model in [MMSegmentation](https://github.com/open-mmlab/mmsegmentation).
All `MODELS` registries of downstream codebases are child registries of MMCV's `MODELS` registry.
Basically, there are two ways to build a module from a child or sibling registry.
1. Build from children registries.
For example:
In MMDetection we define:
```python
from mmengine.registry import Registry
from mmengine.registry import MODELS as MMENGINE_MODELS
MODELS = Registry('model', parent=MMENGINE_MODELS)
@MODELS.register_module()
class NetA(nn.Module):
    def forward(self, x):
        return x
```
In MMClassification we define:
```python
from mmengine.registry import Registry
from mmengine.registry import MODELS as MMENGINE_MODELS
MODELS = Registry('model', parent=MMENGINE_MODELS)
@MODELS.register_module()
class NetB(nn.Module):
    def forward(self, x):
        return x + 1
```
We can build the two nets in either MMDetection or MMClassification by:
```python
from mmdet.models import MODELS
net_a = MODELS.build(cfg=dict(type='NetA'))
net_b = MODELS.build(cfg=dict(type='mmcls.NetB'))
```
or
```python
from mmcls.models import MODELS
net_a = MODELS.build(cfg=dict(type='mmdet.NetA'))
net_b = MODELS.build(cfg=dict(type='NetB'))
```
2. Build from parent registry.
The shared `MODELS` registry in MMCV is the parent registry for all downstream codebases (root registry):
```python
from mmengine.registry import MODELS as MMENGINE_MODELS
net_a = MMENGINE_MODELS.build(cfg=dict(type='mmdet.NetA'))
net_b = MMENGINE_MODELS.build(cfg=dict(type='mmcls.NetB'))
```
## Utils
### ProgressBar
If you want to apply a method to a list of items and track the progress, `track_progress`
is a good choice. It displays a progress bar showing the progress and the ETA.
```python
import mmcv
def func(item):
    # do something
    pass
tasks = [item_1, item_2, ..., item_n]
mmcv.track_progress(func, tasks)
```
The output is like the following.
![progress](../_static/progress.*)
There is another method `track_parallel_progress`, which wraps multiprocessing and
progress visualization.
```python
mmcv.track_parallel_progress(func, tasks, 8) # 8 workers
```
![progress](../_static/parallel_progress.*)
If you want to iterate or enumerate a list of items and track the progress, `track_iter_progress`
is a good choice. It displays a progress bar showing the progress and the ETA.
```python
import mmcv
tasks = [item_1, item_2, ..., item_n]
for task in mmcv.track_iter_progress(tasks):
    # do something like print
    print(task)

for i, task in enumerate(mmcv.track_iter_progress(tasks)):
    # do something like print
    print(i)
    print(task)
```
### Timer
It is convenient to compute the runtime of a code block with `Timer`.
```python
import time
with mmcv.Timer():
    # simulate some code block
    time.sleep(1)
```
Or try `since_start()` and `since_last_check()`. The former returns the runtime
since the timer started, and the latter returns the time since the last check.
```python
timer = mmcv.Timer()
# code block 1 here
print(timer.since_start())
# code block 2 here
print(timer.since_last_check())
print(timer.since_start())
```
@@ -15,14 +15,11 @@
   :maxdepth: 2
   :caption: Understand MMCV

-  understand_mmcv/config.md
-  understand_mmcv/registry.md
   understand_mmcv/data_process.md
   understand_mmcv/data_transform.md
   understand_mmcv/visualization.md
   understand_mmcv/cnn.md
   understand_mmcv/ops.md
-  understand_mmcv/utils.md

.. toctree::
   :maxdepth: 2
...
## Config
The `Config` class is used for manipulating configs and config files. It supports loading configs from multiple file formats, including **python**, **json** and **yaml**.
It provides dict-like APIs to get and set values.
Take the config file `test.py` as an example.
```python
a = 1
b = dict(b1=[0, 1, 2], b2=None)
c = (1, 2)
d = 'string'
```
To load and use the config file
```python
>>> cfg = Config.fromfile('test.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=None),
... c=(1, 2),
... d='string')
```
Predefined variables are supported in all config formats; every occurrence of `{{ var }}` is replaced with its actual value.
Currently, four predefined variables are supported:
`{{ fileDirname }}` - the current opened file's dirname, e.g. /home/your-username/your-project/folder
`{{ fileBasename }}` - the current opened file's basename, e.g. file.ext
`{{ fileBasenameNoExtension }}` - the current opened file's basename without the extension, e.g. file
`{{ fileExtname }}` - the current opened file's extension, e.g. .ext
These variable names are borrowed from [VS Code](https://code.visualstudio.com/docs/editor/variables-reference).
Here is an example of a config file that uses predefined variables.
`config_a.py`
```python
a = 1
b = './work_dir/{{ fileBasenameNoExtension }}'
c = '{{ fileExtname }}'
```
```python
>>> cfg = Config.fromfile('./config_a.py')
>>> print(cfg)
>>> dict(a=1,
... b='./work_dir/config_a',
... c='.py')
```
Inheritance is supported in all config formats. To reuse fields in other config files,
specify `_base_='./config_a.py'` or a list of configs `_base_=['./config_a.py', './config_b.py']`.
Here are 4 examples of config inheritance.
`config_a.py` serves as the base config file.
```python
a = 1
b = dict(b1=[0, 1, 2], b2=None)
```
### Inherit from the base config without overlapping keys
`config_b.py`
```python
_base_ = './config_a.py'
c = (1, 2)
d = 'string'
```
```python
>>> cfg = Config.fromfile('./config_b.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=None),
... c=(1, 2),
... d='string')
```
New fields in `config_b.py` are combined with the old fields in `config_a.py`.
### Inherit from the base config with overlapping keys
`config_c.py`
```python
_base_ = './config_a.py'
b = dict(b2=1)
c = (1, 2)
```
```python
>>> cfg = Config.fromfile('./config_c.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=1),
... c=(1, 2))
```
`b.b2=None` in the base config `config_a.py` is replaced with `b.b2=1` from `config_c.py`.
### Inherit from the base config with ignored fields
`config_d.py`
```python
_base_ = './config_a.py'
b = dict(_delete_=True, b2=None, b3=0.1)
c = (1, 2)
```
```python
>>> cfg = Config.fromfile('./config_d.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b2=None, b3=0.1),
... c=(1, 2))
```
You may also set `_delete_=True` to ignore some fields in the base config. The old keys `b1, b2` in `b` are replaced with the new keys `b2, b3`.
### Inherit from multiple base configs (the base configs should not contain the same keys)
`config_e.py`
```python
c = (1, 2)
d = 'string'
```
`config_f.py`
```python
_base_ = ['./config_a.py', './config_e.py']
```
```python
>>> cfg = Config.fromfile('./config_f.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=None),
... c=(1, 2),
... d='string')
```
### Reference variables from the base config
You can reference variables defined in a base config using the following syntax.
`base.py`
```python
item1 = 'a'
item2 = dict(item3 = 'b')
```
`config_g.py`
```python
_base_ = ['./base.py']
item = dict(a = {{ _base_.item1 }}, b = {{ _base_.item2.item3 }})
```
```python
>>> cfg = Config.fromfile('./config_g.py')
>>> print(cfg.pretty_text)
item1 = 'a'
item2 = dict(item3='b')
item = dict(a='a', b='b')
```
## Registry
MMCV uses a [registry](https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py) to manage different modules that share similar functionalities, e.g., backbones, heads, and necks in detectors.
Most open-source projects in OpenMMLab use registries to manage the modules of datasets and models, such as [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d), [MMClassification](https://github.com/open-mmlab/mmclassification), [MMEditing](https://github.com/open-mmlab/mmediting), etc.
```{note}
Registering and calling functions is supported since v1.5.1.
```
### What is a registry
In MMCV, a registry can be regarded as a mapping from strings to classes or functions.
The classes or functions contained in a single registry usually have similar interfaces but implement different algorithms or support different datasets.
With a registry, users can find a class or function through its corresponding string, and instantiate the corresponding module or call the function as needed.
One typical example is the config system in most OpenMMLab projects, which uses the registry to create hooks, runners, models, and datasets from configs.
The API reference can be found [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.Registry).
To manage your modules in the codebase with `Registry`, there are three steps as below.
1. Create a build method (optional; in most cases the default one is enough).
2. Create a registry.
3. Use this registry to manage the modules.
The `build_func` argument of `Registry` customizes how the class instance is built or how the function is called to obtain the result; the default is `build_from_cfg`, implemented [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.build_from_cfg).
### A Simple Example
Here is a simple example of using a registry to manage modules in a package. You can find more practical examples in the OpenMMLab projects.
Assume we want to implement a series of dataset converters that convert data in different formats into the expected data format. We first create a directory as a package named `converters`. In the package, we create a file to implement builders, named `converters/builder.py`, as below
```python
from mmengine.registry import Registry
# create a registry for converters
CONVERTERS = Registry('converter')
```
Then we can implement different converters in the package; each can be a class or a function. For example, implement `Converter1` in `converters/converter1.py`, and `converter2` in `converters/converter2.py`.
```python
# converter1.py
from .builder import CONVERTERS
# use the registry to manage the module
@CONVERTERS.register_module()
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b
```
```python
# converter2.py
from .builder import CONVERTERS
from .converter1 import Converter1
# use the registry to manage the module
@CONVERTERS.register_module()
def converter2(a, b):
    return Converter1(a, b)
```
The key step to use the registry for managing the modules is to register the implemented module into the registry `CONVERTERS` through `@CONVERTERS.register_module()` when you are creating the module.
In this way, a mapping between a string and the class (or function) is built and maintained by `CONVERTERS`, as below:
```python
'Converter1' -> <class 'Converter1'>
'converter2' -> <function 'converter2'>
```
```{note}
The registry mechanism is triggered only when the file where the module is located is imported, so you need to import that file somewhere. More details can be found at https://github.com/open-mmlab/mmdetection/issues/5974.
```
Once the module is successfully registered, you can use this converter through configs as follows:
```python
converter1_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter2_cfg = dict(type='converter2', a=a_value, b=b_value)
converter1 = CONVERTERS.build(converter1_cfg)
# returns the calling result
result = CONVERTERS.build(converter2_cfg)
```
### Customize the Build Function
Suppose we would like to customize how `converters` are built; we can implement a customized `build_func` and pass it into the registry.
```python
from mmcv.utils import Registry
# create a build function
def build_converter(cfg, registry, *args, **kwargs):
    cfg_ = cfg.copy()
    converter_type = cfg_.pop('type')
    if converter_type not in registry:
        raise KeyError(f'Unrecognized converter type {converter_type}')
    else:
        converter_cls = registry.get(converter_type)

    converter = converter_cls(*args, **kwargs, **cfg_)
    return converter
# create a registry for converters and pass the ``build_converter`` function
CONVERTERS = Registry('converter', build_func=build_converter)
```
```{note}
In this example, we demonstrate how to use the `build_func` argument to customize the way a class instance is built.
The functionality is similar to the default `build_from_cfg`; in most cases the default one is sufficient.
```
`build_model_from_cfg` is also provided to build PyTorch modules into an `nn.Sequential`; you can use it directly instead of implementing your own.
### Hierarchy Registry
You can also build modules from more than one OpenMMLab framework, e.g. you can use all backbones in [MMClassification](https://github.com/open-mmlab/mmclassification) for object detectors in [MMDetection](https://github.com/open-mmlab/mmdetection), or combine an object detection model in [MMDetection](https://github.com/open-mmlab/mmdetection) with a semantic segmentation model in [MMSegmentation](https://github.com/open-mmlab/mmsegmentation).
All `MODELS` registries of downstream codebases are child registries of MMCV's `MODELS` registry. Basically, there are two ways to build a module from a child or sibling registry.
1. Build from a child registry.
For example:
In MMDetection we define:
```python
from mmengine.registry import Registry
from mmengine.registry import MODELS as MMENGINE_MODELS
MODELS = Registry('model', parent=MMENGINE_MODELS)
@MODELS.register_module()
class NetA(nn.Module):
    def forward(self, x):
        return x
```
In MMClassification we define:
```python
from mmengine.registry import Registry
from mmengine.registry import MODELS as MMENGINE_MODELS
MODELS = Registry('model', parent=MMENGINE_MODELS)
@MODELS.register_module()
class NetB(nn.Module):
    def forward(self, x):
        return x + 1
```
We can build the two nets in either MMDetection or MMClassification by:
```python
from mmdet.models import MODELS
net_a = MODELS.build(cfg=dict(type='NetA'))
net_b = MODELS.build(cfg=dict(type='mmcls.NetB'))
```
or
```python
from mmcls.models import MODELS
net_a = MODELS.build(cfg=dict(type='mmdet.NetA'))
net_b = MODELS.build(cfg=dict(type='NetB'))
```
2. Build from the parent registry.
The shared `MODELS` registry in MMCV is the parent (root) registry for all downstream codebases:
```python
from mmengine.registry import MODELS as MMENGINE_MODELS
net_a = MMENGINE_MODELS.build(cfg=dict(type='mmdet.NetA'))
net_b = MMENGINE_MODELS.build(cfg=dict(type='mmcls.NetB'))
```
## Utils
### ProgressBar
If you want to apply a method to a list of items and track the progress, `track_progress` is a good choice. It displays a progress bar showing the progress and the ETA (internally it is a for loop).
```python
import mmcv
def func(item):
    # do something
    pass
tasks = [item_1, item_2, ..., item_n]
mmcv.track_progress(func, tasks)
```
The output looks like the following.
![progress](../../en/_static/progress.*)
If you want to visualize the progress of multiprocessing tasks, you can use `track_parallel_progress`.
```python
mmcv.track_parallel_progress(func, tasks, 8) # 8 workers
```
![progress](../../_static/parallel_progress.*)
If you want to iterate or enumerate a list of items and track the progress, you can use `track_iter_progress`.
```python
import mmcv
tasks = [item_1, item_2, ..., item_n]
for task in mmcv.track_iter_progress(tasks):
    # do something like print
    print(task)

for i, task in enumerate(mmcv.track_iter_progress(tasks)):
    # do something like print
    print(i)
    print(task)
```
### Timer
It is convenient to compute the runtime of a code block with the `Timer` provided by mmcv.
```python
import time
with mmcv.Timer():
    # simulate some code block
    time.sleep(1)
```
You can also use `since_start()` and `since_last_check()`. The former returns the runtime since the timer started, and the latter returns the time since the last check.
```python
timer = mmcv.Timer()
# code block 1 here
print(timer.since_start())
# code block 2 here
print(timer.since_last_check())
print(timer.since_start())
```
@@ -3,7 +3,6 @@
from .arraymisc import *
from .image import *
from .transforms import *
-from .utils import *
from .version import *
from .video import *
from .visualization import *
@@ -11,3 +10,4 @@ from .visualization import *
# The following modules are not imported to this level, so mmcv may be used
# without PyTorch.
# - op
+# - utils
@@ -6,8 +6,8 @@ import torch
import torch.nn as nn
from mmengine.model.utils import constant_init, kaiming_init
from mmengine.registry import MODELS
+from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm
-from mmcv.utils import _BatchNorm, _InstanceNorm
from .activation import build_activation_layer
from .conv import build_conv_layer
from .norm import build_norm_layer
...
@@ -2,8 +2,7 @@
import torch
import torch.nn as nn
from mmengine.registry import MODELS
+from mmengine.utils import TORCH_VERSION, digit_version
-from mmcv.utils import TORCH_VERSION, digit_version
class HSwish(nn.Module):
...
@@ -4,9 +4,9 @@ from typing import Dict, Tuple, Union
import torch.nn as nn
from mmengine.registry import MODELS
+from mmengine.utils import is_tuple_of
+from mmengine.utils.parrots_wrapper import (SyncBatchNorm, _BatchNorm,
+                                            _InstanceNorm)
-from mmcv.utils import is_tuple_of
-from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
MODELS.register_module('BN', module=nn.BatchNorm2d)
MODELS.register_module('BN1d', module=nn.BatchNorm1d)
...
@@ -10,10 +10,10 @@ import torch.nn.functional as F
from mmengine import ConfigDict
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.registry import MODELS
+from mmengine.utils import deprecated_api_warning, to_2tuple
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
                      build_norm_layer)
-from mmcv.utils import deprecated_api_warning, to_2tuple
from .drop import build_dropout
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
...
# Copyright (c) OpenMMLab. All rights reserved.
from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
-from .sync_bn import revert_sync_batchnorm
-__all__ = [
-    'get_model_complexity_info', 'fuse_conv_bn', 'revert_sync_batchnorm'
-]
+__all__ = ['get_model_complexity_info', 'fuse_conv_bn']
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

import mmcv


class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
    """A general BatchNorm layer without input dimension check.

    Reproduced from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
    is `_check_input_dim` that is designed for tensor sanity checks.
    The check has been bypassed in this class for the convenience of
    converting SyncBatchNorm.
    """

    def _check_input_dim(self, input: torch.Tensor):
        return


def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
    `BatchNormXd` layers.

    Adapted from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)

    Args:
        module (nn.Module): The module containing `SyncBatchNorm` layers.

    Returns:
        module_output: The converted module with `BatchNormXd` layers.
    """
    module_output = module
    module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
    if hasattr(mmcv, 'ops'):
        module_checklist.append(mmcv.ops.SyncBatchNorm)
    if isinstance(module, tuple(module_checklist)):
        module_output = _BatchNormXd(module.num_features, module.eps,
                                     module.momentum, module.affine,
                                     module.track_running_stats)
        if module.affine:
            # no_grad() may not be needed here but
            # just to be consistent with `convert_sync_batchnorm()`
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        module_output.training = module.training
        # qconfig exists in quantized models
        if hasattr(module, 'qconfig'):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    del module
    return module_output
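For reference, a small usage sketch of the helper above (the toy model is illustrative, not part of the original file); after conversion the model no longer needs a distributed process group to run a forward pass:

```python
import torch
import torch.nn as nn

# a toy model that uses SyncBatchNorm, e.g. one exported from a distributed run
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8), nn.ReLU())
model = revert_sync_batchnorm(model)

# every SyncBatchNorm has been replaced by a _BatchNormXd, so this works
# on a single CPU/GPU without init_process_group
with torch.no_grad():
    out = model(torch.randn(2, 3, 32, 32))
print(type(model[1]).__name__)  # _BatchNormXd
```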
@@ -5,8 +5,8 @@ from typing import Optional, Tuple
import cv2
import numpy as np
+from mmengine.utils import to_2tuple
-from ..utils import to_2tuple
from .io import imread_backend
try:
...
@@ -9,8 +9,7 @@ import numpy as np
from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
                 IMREAD_UNCHANGED)
from mmengine.fileio import FileClient
+from mmengine.utils import is_filepath, is_str
-from mmcv.utils import is_filepath, is_str
try:
    from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
...
@@ -4,9 +4,9 @@ from typing import Optional
import cv2
import numpy as np
+from mmengine.utils import is_tuple_of
from PIL import Image, ImageEnhance
-from ..utils import is_tuple_of
from .colorspace import bgr2gray, gray2bgr
from .io import imread_backend
...
@@ -6,12 +6,12 @@ import torch.nn as nn
import torch.nn.functional as F
from mmengine import print_log
from mmengine.registry import MODELS
+from mmengine.utils import deprecated_api_warning
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
-from mmcv.utils import deprecated_api_warning
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', [
...