Commit 2c9129d4 authored by zhangwenwei's avatar zhangwenwei
Browse files

Fix optimizer

parent b77a77d4
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, Runner
from mmcv.runner import DistSamplerSeedHook, Runner, build_optimizer
from mmdet3d.utils import get_root_logger
from mmdet.apis.train import parse_losses
from mmdet.core import (DistEvalHook, DistOptimizerHook, EvalHook,
Fp16OptimizerHook, build_optimizer)
Fp16OptimizerHook)
from mmdet.datasets import build_dataloader, build_dataset
......
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmcv.utils import build_from_cfg
from mmdet3d.utils import get_root_logger
from mmdet.core.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS
from .cocktail_optimizer import CocktailOptimizer
......
from mmcv.runner.optimizer import OPTIMIZERS
from torch.optim import Optimizer
from mmdet.core.optimizer import OPTIMIZERS
@OPTIMIZERS.register_module()
class CocktailOptimizer(Optimizer):
......@@ -9,6 +8,11 @@ class CocktailOptimizer(Optimizer):
This optimizer applies the cocktail optimization for multi-modality models.
Args:
optimizers (list[:obj:`torch.optim.Optimizer`]): The list containing
different optimizers that optimize different parameters
step_intervals (list[int]): Step intervals of each optimizer
"""
def __init__(self, optimizers, step_intervals=None):
......@@ -18,6 +22,9 @@ class CocktailOptimizer(Optimizer):
self.param_groups += optimizer.param_groups
if not isinstance(step_intervals, list):
step_intervals = [1] * len(self.optimizers)
assert len(step_intervals) == len(optimizers), \
'"step_intervals" should contain the same number of intervals as' \
f'len(optimizers)={len(optimizers)}, got {step_intervals}'
self.step_intervals = step_intervals
self.num_step_updated = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment