"vscode:/vscode.git/clone" did not exist on "cdc56ef6c1c6f359de87c5f78a66316723557d5d"
Commit e64d4fed authored by myhloli's avatar myhloli
Browse files

refactor(table): add device configuration for Unitable model

- Import get_device function from magic_pdf.libs.config_reader- Update RapidTableModel initialization to include device parameter for Unitable model
parent 48c20514
......@@ -5,6 +5,8 @@ from loguru import logger
from rapid_table import RapidTable, RapidTableInput
from rapid_table.main import ModelType
from magic_pdf.libs.config_reader import get_device
class RapidTableModel(object):
def __init__(self, ocr_engine, table_sub_model_name):
......@@ -13,7 +15,7 @@ class RapidTableModel(object):
input_args = RapidTableInput()
elif table_sub_model_name in sub_model_list:
if torch.cuda.is_available() and table_sub_model_name == "unitable":
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True)
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
else:
input_args = RapidTableInput(model_type=table_sub_model_name)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment