"git@developer.sourcefind.cn:change/sglang.git" did not exist on "b39532587bc9328a42f99b522cb9997a210a32b2"
Unverified Commit 6fa8850e authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

[Fix]: fix init_model_api (#1739)

parent 4e3cfbe4
...@@ -46,10 +46,7 @@ def main(args): ...@@ -46,10 +46,7 @@ def main(args):
# init visualizer # init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer) visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = { visualizer.dataset_meta = model.dataset_meta
'CLASSES': model.CLASSES,
'PALETTE': model.PALETTE
}
# test a single image # test a single image
result = inference_mono_3d_detector(model, args.img, args.ann, result = inference_mono_3d_detector(model, args.img, args.ann,
......
...@@ -47,10 +47,7 @@ def main(args): ...@@ -47,10 +47,7 @@ def main(args):
# init visualizer # init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer) visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = { visualizer.dataset_meta = model.dataset_meta
'CLASSES': model.CLASSES,
'PALETTE': model.PALETTE
}
# test a single image and point cloud sample # test a single image and point cloud sample
result, data = inference_multi_modality_detector(model, args.pcd, args.img, result, data = inference_multi_modality_detector(model, args.pcd, args.img,
......
...@@ -36,10 +36,7 @@ def main(args): ...@@ -36,10 +36,7 @@ def main(args):
# init visualizer # init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer) visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = { visualizer.dataset_meta = model.dataset_meta
'CLASSES': model.CLASSES,
'PALETTE': model.PALETTE
}
# test a single point cloud sample # test a single point cloud sample
result, data = inference_segmentor(model, args.pcd) result, data = inference_segmentor(model, args.pcd)
......
...@@ -39,9 +39,7 @@ def main(args): ...@@ -39,9 +39,7 @@ def main(args):
# init visualizer # init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer) visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = { visualizer.dataset_meta = model.dataset_meta
'CLASSES': model.CLASSES,
}
# test a single point cloud sample # test a single point cloud sample
result, data = inference_detector(model, args.pcd) result, data = inference_detector(model, args.pcd)
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
import warnings import warnings
from copy import deepcopy from copy import deepcopy
from os import path as osp from os import path as osp
from typing import Sequence, Union from pathlib import Path
from typing import Optional, Sequence, Union
import mmengine import mmengine
import numpy as np import numpy as np
...@@ -33,36 +34,60 @@ def convert_SyncBN(config): ...@@ -33,36 +34,60 @@ def convert_SyncBN(config):
convert_SyncBN(config[item]) convert_SyncBN(config[item])
def init_model(config, checkpoint=None, device='cuda:0'): def init_model(config: Union[str, Path, Config],
checkpoint: Optional[str] = None,
device: str = 'cuda:0',
cfg_options: Optional[dict] = None):
"""Initialize a model from config file, which could be a 3D detector or a """Initialize a model from config file, which could be a 3D detector or a
3D segmentor. 3D segmentor.
Args: Args:
config (str or :obj:`mmengine.Config`): Config file path or the config config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
object. :obj:`Path`, or the config object.
checkpoint (str, optional): Checkpoint path. If left as None, the model checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights. will not load any weights.
device (str): Device to use. device (str): Device to use.
cfg_options (dict, optional): Options to override some settings in
the used config.
Returns: Returns:
nn.Module: The constructed detector. nn.Module: The constructed detector.
""" """
if isinstance(config, str): if isinstance(config, (str, Path)):
config = Config.fromfile(config) config = Config.fromfile(config)
elif not isinstance(config, Config): elif not isinstance(config, Config):
raise TypeError('config must be a filename or Config object, ' raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}') f'but got {type(config)}')
if cfg_options is not None:
config.merge_from_dict(cfg_options)
elif 'init_cfg' in config.model.backbone:
config.model.backbone.init_cfg = None
convert_SyncBN(config.model) convert_SyncBN(config.model)
config.model.train_cfg = None config.model.train_cfg = None
model = MODELS.build(config.model) model = MODELS.build(config.model)
if checkpoint is not None: if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if 'CLASSES' in checkpoint['meta']:
model.CLASSES = checkpoint['meta']['CLASSES'] dataset_meta = checkpoint['meta'].get('dataset_meta', None)
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint.get('meta', {}):
# mmdet3d 1.x
model.dataset_meta = dataset_meta
elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES']
model.dataset_meta = {'CLASSES': classes}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
else: else:
model.CLASSES = config.class_names # < mmdet3d 1.x
if 'PALETTE' in checkpoint['meta']: # 3D Segmentor model.dataset_meta = {'CLASSES': config.class_names}
model.PALETTE = checkpoint['meta']['PALETTE']
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.cfg = config # save the config in the model for convenience model.cfg = config # save the config in the model for convenience
if device != 'cpu': if device != 'cpu':
torch.cuda.set_device(device) torch.cuda.set_device(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment