Commit 45e7fbd2 authored by myhloli's avatar myhloli
Browse files

feat(model-config): Unify all device selections through a single YAML file

parent bc0f6932
......@@ -19,8 +19,8 @@ from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
def layout_model_init(weight, config_file):
model = Layoutlmv3_Predictor(weight, config_file)
def layout_model_init(weight, config_file, device):
model = Layoutlmv3_Predictor(weight, config_file, device)
return model
......@@ -89,7 +89,8 @@ class CustomPEKModel:
# 初始化layout模型
self.layout_model = layout_model_init(
os.path.join(root_dir, self.configs['weights']['layout']),
os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")
os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"),
device=self.device
)
# 初始化公式识别
if self.apply_formula:
......
......@@ -61,16 +61,21 @@ def add_vit_config(cfg):
_C.SOLVER.GRADIENT_ACCUMULATION_STEPS = 1
def setup(args):
def setup(args, device):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
# add_coat_config(cfg)
add_vit_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2 # set threshold for this model
cfg.merge_from_list(args.opts)
# 使用统一的device配置
cfg.MODEL.DEVICE = device
cfg.freeze()
default_setup(cfg, args)
......@@ -101,7 +106,7 @@ class DotDict(dict):
class Layoutlmv3_Predictor(object):
def __init__(self, weights, config_file):
def __init__(self, weights, config_file, device):
layout_args = {
"config_file": config_file,
"resume": False,
......@@ -114,7 +119,7 @@ class Layoutlmv3_Predictor(object):
}
layout_args = DotDict(layout_args)
cfg = setup(layout_args)
cfg = setup(layout_args, device)
self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption",
"table_footnote", "isolate_formula", "formula_caption"]
MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping
......
......@@ -69,7 +69,7 @@ MODEL:
FREEZE_AT: 2
NAME: build_vit_fpn_backbone
CONFIG_PATH: ''
DEVICE: cpu
DEVICE: cuda
FPN:
FUSE_TYPE: sum
IN_FEATURES:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment