"vscode:/vscode.git/clone" did not exist on "18ead1935557a2d11cac44bb5dfd82f3d63ea682"
Unverified Commit 56e9f73c authored by Luting Wang's avatar Luting Wang Committed by GitHub
Browse files

[Feature] Add ModuleDict inherited from BaseModule and ModuleDict (#1542)

* improve: add module dict

* Update __init__.py

* docs: add module dict zh

* docs: add module dict

* Update test_basemodule.py

* Update __init__.py

* lint

* lint

* docs

* docs

* Update base_module.py

* lint
parent 72182747
../../CONTRIBUTING.md ../../CONTRIBUTING.md
\ No newline at end of file
...@@ -366,7 +366,7 @@ Let us introduce the usage of `initialize` in detail. ...@@ -366,7 +366,7 @@ Let us introduce the usage of `initialize` in detail.
initialize(model, init_cfg) initialize(model, init_cfg)
``` ```
4. Initialize model inherited from BaseModule, Sequential, ModuleList 4. Initialize model inherited from BaseModule, Sequential, ModuleList, ModuleDict
`BaseModule` is inherited from `torch.nn.Module`, and the only difference between them is that `BaseModule` implements `init_weights`.
...@@ -374,9 +374,11 @@ Let us introduce the usage of `initialize` in detail. ...@@ -374,9 +374,11 @@ Let us introduce the usage of `initialize` in detail.
`ModuleList` is inherited from `BaseModule` and `torch.nn.ModuleList`. `ModuleList` is inherited from `BaseModule` and `torch.nn.ModuleList`.
`ModuleDict` is inherited from `BaseModule` and `torch.nn.ModuleDict`.
`````python `````python
import torch.nn as nn import torch.nn as nn
from mmcv.runner import BaseModule, Sequential, ModuleList from mmcv.runner import BaseModule, Sequential, ModuleList, ModuleDict
class FooConv1d(BaseModule): class FooConv1d(BaseModule):
...@@ -494,6 +496,49 @@ Let us introduce the usage of `initialize` in detail. ...@@ -494,6 +496,49 @@ Let us introduce the usage of `initialize` in detail.
# [[2., 2., 2.], # [[2., 2., 2.],
# [2., 2., 2.], # [2., 2., 2.],
# [2., 2., 2.]]]], requires_grad=True) # [2., 2., 2.]]]], requires_grad=True)
# ModuleDict
model1 = FooConv1d(init_cfg1)
model2 = FooConv2d(init_cfg2)
modeldict = ModuleDict(dict(model1=model1, model2=model2))
modeldict.init_weights()
# modeldict['model1'].conv1d.weight
# Parameter containing:
# tensor([[[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]]], requires_grad=True)
# modeldict['model2'].conv2d.weight
# Parameter containing:
# tensor([[[[2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.]],
# ...,
# [[2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.]]]], requires_grad=True)
# inner init_cfg has higher priority
model1 = FooConv1d(init_cfg1)
model2 = FooConv2d(init_cfg2)
init_cfg = dict(type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)
modeldict = ModuleDict(dict(model1=model1, model2=model2), init_cfg=init_cfg)
modeldict.init_weights()
# modeldict['model1'].conv1d.weight
# Parameter containing:
# tensor([[[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]]], requires_grad=True)
# modeldict['model2'].conv2d.weight
# Parameter containing:
# tensor([[[[2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.]],
# ...,
# [[2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.]]]], requires_grad=True)
````` `````
### Model Zoo ### Model Zoo
......
...@@ -355,7 +355,7 @@ conv = ConvModule( ...@@ -355,7 +355,7 @@ conv = ConvModule(
initialize(model, init_cfg) initialize(model, init_cfg)
``` ```
4. 初始化继承自BaseModule、Sequential、ModuleList的模型 4. 初始化继承自BaseModule、Sequential、ModuleList、ModuleDict的模型
`BaseModule` 继承自 `torch.nn.Module`, 它们之间唯一的不同是 `BaseModule` 实现了 `init_weights`
...@@ -363,9 +363,11 @@ conv = ConvModule( ...@@ -363,9 +363,11 @@ conv = ConvModule(
`ModuleList` 继承自 `BaseModule` 和 `torch.nn.ModuleList` `ModuleList` 继承自 `BaseModule` 和 `torch.nn.ModuleList`
`ModuleDict` 继承自 `BaseModule` 和 `torch.nn.ModuleDict`
`````python `````python
import torch.nn as nn import torch.nn as nn
from mmcv.runner import BaseModule, Sequential, ModuleList from mmcv.runner import BaseModule, Sequential, ModuleList, ModuleDict
class FooConv1d(BaseModule): class FooConv1d(BaseModule):
...@@ -483,6 +485,49 @@ conv = ConvModule( ...@@ -483,6 +485,49 @@ conv = ConvModule(
# [[2., 2., 2.], # [[2., 2., 2.],
# [2., 2., 2.], # [2., 2., 2.],
# [2., 2., 2.]]]], requires_grad=True) # [2., 2., 2.]]]], requires_grad=True)
# ModuleDict
model1 = FooConv1d(init_cfg1)
model2 = FooConv2d(init_cfg2)
modeldict = ModuleDict(dict(model1=model1, model2=model2))
modeldict.init_weights()
# modeldict['model1'].conv1d.weight
# Parameter containing:
# tensor([[[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]]], requires_grad=True)
# modeldict['model2'].conv2d.weight
# Parameter containing:
# tensor([[[[2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.]],
# ...,
# [[2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.]]]], requires_grad=True)
# inner init_cfg has higher priority
model1 = FooConv1d(init_cfg1)
model2 = FooConv2d(init_cfg2)
init_cfg = dict(type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)
modeldict = ModuleDict(dict(model1=model1, model2=model2), init_cfg=init_cfg)
modeldict.init_weights()
# modeldict['model1'].conv1d.weight
# Parameter containing:
# tensor([[[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]]], requires_grad=True)
# modeldict['model2'].conv2d.weight
# Parameter containing:
# tensor([[[[2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.]],
# ...,
# [[2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.]]]], requires_grad=True)
````` `````
### Model Zoo ### Model Zoo
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .base_module import BaseModule, ModuleList, Sequential from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .base_runner import BaseRunner from .base_runner import BaseRunner
from .builder import RUNNERS, build_runner from .builder import RUNNERS, build_runner
from .checkpoint import (CheckpointLoader, _load_checkpoint, from .checkpoint import (CheckpointLoader, _load_checkpoint,
...@@ -42,6 +42,6 @@ __all__ = [ ...@@ -42,6 +42,6 @@ __all__ = [
'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads', 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule', 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
'_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential', '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
'ModuleList', 'GradientCumulativeOptimizerHook', 'ModuleDict', 'ModuleList', 'GradientCumulativeOptimizerHook',
'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor' 'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
] ]
...@@ -193,3 +193,17 @@ class ModuleList(BaseModule, nn.ModuleList): ...@@ -193,3 +193,17 @@ class ModuleList(BaseModule, nn.ModuleList):
def __init__(self, modules=None, init_cfg=None): def __init__(self, modules=None, init_cfg=None):
BaseModule.__init__(self, init_cfg) BaseModule.__init__(self, init_cfg)
nn.ModuleList.__init__(self, modules) nn.ModuleList.__init__(self, modules)
class ModuleDict(BaseModule, nn.ModuleDict):
    """ModuleDict in openmmlab.

    Combines ``BaseModule`` with ``torch.nn.ModuleDict`` so that a dict of
    sub-modules also carries an ``init_cfg`` and can be initialized through
    ``init_weights()`` (inherited from ``BaseModule``).

    Args:
        modules (dict, optional): a mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module).
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, modules=None, init_cfg=None):
        # Each base class is initialized explicitly (rather than via
        # super()) because they take different constructor arguments:
        # BaseModule consumes ``init_cfg`` and nn.ModuleDict stores
        # ``modules``.
        BaseModule.__init__(self, init_cfg)
        nn.ModuleDict.__init__(self, modules)
...@@ -6,7 +6,7 @@ from torch import nn ...@@ -6,7 +6,7 @@ from torch import nn
import mmcv import mmcv
from mmcv.cnn.utils.weight_init import update_init_info from mmcv.cnn.utils.weight_init import update_init_info
from mmcv.runner import BaseModule, ModuleList, Sequential from mmcv.runner import BaseModule, ModuleDict, ModuleList, Sequential
from mmcv.utils import Registry, build_from_cfg from mmcv.utils import Registry, build_from_cfg
COMPONENTS = Registry('component') COMPONENTS = Registry('component')
...@@ -555,3 +555,54 @@ def test_modulelist_weight_init(): ...@@ -555,3 +555,54 @@ def test_modulelist_weight_init():
torch.full(modellist[1].conv2d.weight.shape, 2.)) torch.full(modellist[1].conv2d.weight.shape, 2.))
assert torch.equal(modellist[1].conv2d.bias, assert torch.equal(modellist[1].conv2d.bias,
torch.full(modellist[1].conv2d.bias.shape, 3.)) torch.full(modellist[1].conv2d.bias.shape, 3.))
def test_moduledict_weight_init():
    """Weights of modules held in a ModuleDict are initialized from each
    module's own ``init_cfg``, and the inner config wins over an outer one
    passed to the ModuleDict itself."""
    models_cfg = dict(
        foo_conv_1d=dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
        foo_conv_2d=dict(
            type='FooConv2d',
            init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
    )

    def build_layers():
        # Fresh instances each time so both scenarios start uninitialized.
        return {
            key: build_from_cfg(cfg, COMPONENTS)
            for key, cfg in models_cfg.items()
        }

    def assert_inner_cfg_applied(md):
        # Expected constants come from the per-module init_cfg entries above.
        conv1d = md['foo_conv_1d'].conv1d
        conv2d = md['foo_conv_2d'].conv2d
        assert torch.equal(conv1d.weight, torch.full(conv1d.weight.shape, 0.))
        assert torch.equal(conv1d.bias, torch.full(conv1d.bias.shape, 1.))
        assert torch.equal(conv2d.weight, torch.full(conv2d.weight.shape, 2.))
        assert torch.equal(conv2d.bias, torch.full(conv2d.bias.shape, 3.))

    # No outer init_cfg: each child is initialized from its own config.
    modeldict = ModuleDict(build_layers())
    modeldict.init_weights()
    assert_inner_cfg_applied(modeldict)

    # inner init_cfg has higher priority
    modeldict = ModuleDict(
        build_layers(),
        init_cfg=dict(
            type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
    modeldict.init_weights()
    assert_inner_cfg_applied(modeldict)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment