Unverified Commit 6fa8850e authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

[Fix]: fix init_model_api (#1739)

parent 4e3cfbe4
......@@ -46,10 +46,7 @@ def main(args):
# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = {
'CLASSES': model.CLASSES,
'PALETTE': model.PALETTE
}
visualizer.dataset_meta = model.dataset_meta
# test a single image
result = inference_mono_3d_detector(model, args.img, args.ann,
......
......@@ -47,10 +47,7 @@ def main(args):
# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = {
'CLASSES': model.CLASSES,
'PALETTE': model.PALETTE
}
visualizer.dataset_meta = model.dataset_meta
# test a single image and point cloud sample
result, data = inference_multi_modality_detector(model, args.pcd, args.img,
......
......@@ -36,10 +36,7 @@ def main(args):
# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = {
'CLASSES': model.CLASSES,
'PALETTE': model.PALETTE
}
visualizer.dataset_meta = model.dataset_meta
# test a single point cloud sample
result, data = inference_segmentor(model, args.pcd)
......
......@@ -39,9 +39,7 @@ def main(args):
# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = {
'CLASSES': model.CLASSES,
}
visualizer.dataset_meta = model.dataset_meta
# test a single point cloud sample
result, data = inference_detector(model, args.pcd)
......
......@@ -2,7 +2,8 @@
import warnings
from copy import deepcopy
from os import path as osp
from typing import Sequence, Union
from pathlib import Path
from typing import Optional, Sequence, Union
import mmengine
import numpy as np
......@@ -33,36 +34,60 @@ def convert_SyncBN(config):
convert_SyncBN(config[item])
def init_model(config, checkpoint=None, device='cuda:0'):
def init_model(config: Union[str, Path, Config],
checkpoint: Optional[str] = None,
device: str = 'cuda:0',
cfg_options: Optional[dict] = None):
"""Initialize a model from config file, which could be a 3D detector or a
3D segmentor.
Args:
config (str or :obj:`mmengine.Config`): Config file path or the config
object.
config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
:obj:`Path`, or the config object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
device (str): Device to use.
cfg_options (dict, optional): Options to override some settings in
the used config.
Returns:
nn.Module: The constructed detector.
"""
if isinstance(config, str):
if isinstance(config, (str, Path)):
config = Config.fromfile(config)
elif not isinstance(config, Config):
raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}')
if cfg_options is not None:
config.merge_from_dict(cfg_options)
elif 'init_cfg' in config.model.backbone:
config.model.backbone.init_cfg = None
convert_SyncBN(config.model)
config.model.train_cfg = None
model = MODELS.build(config.model)
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if 'CLASSES' in checkpoint['meta']:
model.CLASSES = checkpoint['meta']['CLASSES']
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint.get('meta', {}):
# mmdet3d 1.x
model.dataset_meta = dataset_meta
elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES']
model.dataset_meta = {'CLASSES': classes}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
else:
model.CLASSES = config.class_names
if 'PALETTE' in checkpoint['meta']: # 3D Segmentor
model.PALETTE = checkpoint['meta']['PALETTE']
# < mmdet3d 1.x
model.dataset_meta = {'CLASSES': config.class_names}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.cfg = config # save the config in the model for convenience
if device != 'cpu':
torch.cuda.set_device(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment