Commit 4b4cc7e8 authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'fix-optimizer' into 'master'

Fix optimizer

See merge request open-mmlab/mmdet.3d!58
parents b77a77d4 2c9129d4
import torch import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, Runner from mmcv.runner import DistSamplerSeedHook, Runner, build_optimizer
from mmdet3d.utils import get_root_logger from mmdet3d.utils import get_root_logger
from mmdet.apis.train import parse_losses from mmdet.apis.train import parse_losses
from mmdet.core import (DistEvalHook, DistOptimizerHook, EvalHook, from mmdet.core import (DistEvalHook, DistOptimizerHook, EvalHook,
Fp16OptimizerHook, build_optimizer) Fp16OptimizerHook)
from mmdet.datasets import build_dataloader, build_dataset from mmdet.datasets import build_dataloader, build_dataset
......
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmcv.utils import build_from_cfg from mmcv.utils import build_from_cfg
from mmdet3d.utils import get_root_logger from mmdet3d.utils import get_root_logger
from mmdet.core.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS
from .cocktail_optimizer import CocktailOptimizer from .cocktail_optimizer import CocktailOptimizer
......
from mmcv.runner.optimizer import OPTIMIZERS
from torch.optim import Optimizer from torch.optim import Optimizer
from mmdet.core.optimizer import OPTIMIZERS
@OPTIMIZERS.register_module() @OPTIMIZERS.register_module()
class CocktailOptimizer(Optimizer): class CocktailOptimizer(Optimizer):
...@@ -9,6 +8,11 @@ class CocktailOptimizer(Optimizer): ...@@ -9,6 +8,11 @@ class CocktailOptimizer(Optimizer):
This optimizer applies the cocktail optimzation for multi-modality models. This optimizer applies the cocktail optimzation for multi-modality models.
Args:
optimizers (list[:obj:`torch.optim.Optimizer`]): The list containing
different optimizers that optimize different parameters
step_intervals (list[int]): Step intervals of each optimizer
""" """
def __init__(self, optimizers, step_intervals=None): def __init__(self, optimizers, step_intervals=None):
...@@ -18,6 +22,9 @@ class CocktailOptimizer(Optimizer): ...@@ -18,6 +22,9 @@ class CocktailOptimizer(Optimizer):
self.param_groups += optimizer.param_groups self.param_groups += optimizer.param_groups
if not isinstance(step_intervals, list): if not isinstance(step_intervals, list):
step_intervals = [1] * len(self.optimizers) step_intervals = [1] * len(self.optimizers)
assert len(step_intervals) == len(optimizers), \
'"step_intervals" should contain the same number of intervals as' \
f'len(optimizers)={len(optimizers)}, got {step_intervals}'
self.step_intervals = step_intervals self.step_intervals = step_intervals
self.num_step_updated = 0 self.num_step_updated = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment