Commit 45e7fbd2 authored by myhloli's avatar myhloli
Browse files

feat(model-config): Unify all device selections through a single YAML file

parent bc0f6932
...@@ -19,8 +19,8 @@ from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex ...@@ -19,8 +19,8 @@ from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
def layout_model_init(weight, config_file): def layout_model_init(weight, config_file, device):
model = Layoutlmv3_Predictor(weight, config_file) model = Layoutlmv3_Predictor(weight, config_file, device)
return model return model
...@@ -89,7 +89,8 @@ class CustomPEKModel: ...@@ -89,7 +89,8 @@ class CustomPEKModel:
# 初始化layout模型 # 初始化layout模型
self.layout_model = layout_model_init( self.layout_model = layout_model_init(
os.path.join(root_dir, self.configs['weights']['layout']), os.path.join(root_dir, self.configs['weights']['layout']),
os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml") os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"),
device=self.device
) )
# 初始化公式识别 # 初始化公式识别
if self.apply_formula: if self.apply_formula:
......
...@@ -61,16 +61,21 @@ def add_vit_config(cfg): ...@@ -61,16 +61,21 @@ def add_vit_config(cfg):
_C.SOLVER.GRADIENT_ACCUMULATION_STEPS = 1 _C.SOLVER.GRADIENT_ACCUMULATION_STEPS = 1
def setup(args): def setup(args, device):
""" """
Create configs and perform basic setups. Create configs and perform basic setups.
""" """
cfg = get_cfg() cfg = get_cfg()
# add_coat_config(cfg) # add_coat_config(cfg)
add_vit_config(cfg) add_vit_config(cfg)
cfg.merge_from_file(args.config_file) cfg.merge_from_file(args.config_file)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2 # set threshold for this model cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2 # set threshold for this model
cfg.merge_from_list(args.opts) cfg.merge_from_list(args.opts)
# 使用统一的device配置
cfg.MODEL.DEVICE = device
cfg.freeze() cfg.freeze()
default_setup(cfg, args) default_setup(cfg, args)
...@@ -101,7 +106,7 @@ class DotDict(dict): ...@@ -101,7 +106,7 @@ class DotDict(dict):
class Layoutlmv3_Predictor(object): class Layoutlmv3_Predictor(object):
def __init__(self, weights, config_file): def __init__(self, weights, config_file, device):
layout_args = { layout_args = {
"config_file": config_file, "config_file": config_file,
"resume": False, "resume": False,
...@@ -114,7 +119,7 @@ class Layoutlmv3_Predictor(object): ...@@ -114,7 +119,7 @@ class Layoutlmv3_Predictor(object):
} }
layout_args = DotDict(layout_args) layout_args = DotDict(layout_args)
cfg = setup(layout_args) cfg = setup(layout_args, device)
self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption", self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption",
"table_footnote", "isolate_formula", "formula_caption"] "table_footnote", "isolate_formula", "formula_caption"]
MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping
......
...@@ -69,7 +69,7 @@ MODEL: ...@@ -69,7 +69,7 @@ MODEL:
FREEZE_AT: 2 FREEZE_AT: 2
NAME: build_vit_fpn_backbone NAME: build_vit_fpn_backbone
CONFIG_PATH: '' CONFIG_PATH: ''
DEVICE: cpu DEVICE: cuda
FPN: FPN:
FUSE_TYPE: sum FUSE_TYPE: sum
IN_FEATURES: IN_FEATURES:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment