Commit ba888088 authored by zhangwenwei's avatar zhangwenwei
Browse files

Fix training runtime bug

parent 21599119
...@@ -194,7 +194,7 @@ log_config = dict( ...@@ -194,7 +194,7 @@ log_config = dict(
# yapf:enable # yapf:enable
# runtime settings # runtime settings
total_epochs = 80 total_epochs = 80
dist_params = dict(backend='nccl', port=29511) dist_params = dict(backend='nccl')
log_level = 'INFO' log_level = 'INFO'
work_dir = './work_dirs/sec_secfpn_80e' work_dir = './work_dirs/sec_secfpn_80e'
load_from = None load_from = None
......
...@@ -2,11 +2,11 @@ import torch ...@@ -2,11 +2,11 @@ import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, Runner from mmcv.runner import DistSamplerSeedHook, Runner
from mmdet3d.utils import get_root_logger
from mmdet.apis.train import parse_losses from mmdet.apis.train import parse_losses
from mmdet.core import (DistEvalHook, DistOptimizerHook, EvalHook, from mmdet.core import (DistEvalHook, DistOptimizerHook, EvalHook,
Fp16OptimizerHook, build_optimizer) Fp16OptimizerHook, build_optimizer)
from mmdet.datasets import build_dataloader, build_dataset from mmdet.datasets import build_dataloader, build_dataset
from mmdet.utils import get_root_logger
def batch_processor(model, data, train_mode): def batch_processor(model, data, train_mode):
......
from mmcv.utils import build_from_cfg from mmcv.utils import build_from_cfg
from mmdet3d.utils import get_root_logger
from mmdet.core.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS from mmdet.core.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmdet.utils import get_root_logger
from .cocktail_optimizer import CocktailOptimizer from .cocktail_optimizer import CocktailOptimizer
......
...@@ -68,7 +68,7 @@ class DataBaseSampler(object): ...@@ -68,7 +68,7 @@ class DataBaseSampler(object):
db_infos = pickle.load(f) db_infos = pickle.load(f)
# filter database infos # filter database infos
from mmdet.apis import get_root_logger from mmdet3d.utils import get_root_logger
logger = get_root_logger() logger = get_root_logger()
for k, v in db_infos.items(): for k, v in db_infos.items():
logger.info(f'load {len(v)} {k} database infos') logger.info(f'load {len(v)} {k} database infos')
......
...@@ -72,7 +72,7 @@ class SECOND(nn.Module): ...@@ -72,7 +72,7 @@ class SECOND(nn.Module):
def init_weights(self, pretrained=None): def init_weights(self, pretrained=None):
if isinstance(pretrained, str): if isinstance(pretrained, str):
from mmdet3d.apis import get_root_logger from mmdet3d.utils import get_root_logger
logger = get_root_logger() logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger) load_checkpoint(self, pretrained, strict=False, logger=logger)
......
...@@ -59,7 +59,7 @@ class BaseDetector(nn.Module, metaclass=ABCMeta): ...@@ -59,7 +59,7 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
def init_weights(self, pretrained=None): def init_weights(self, pretrained=None):
if pretrained is not None: if pretrained is not None:
from mmdet3d.apis import get_root_logger from mmdet3d.utils import get_root_logger
logger = get_root_logger() logger = get_root_logger()
logger.info('load model from: {}'.format(pretrained)) logger.info('load model from: {}'.format(pretrained))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment