"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "55f61d1524d2b51bdcb3be51404abc91609c9665"
Commit 7d2dfc80 authored by liukaiwen's avatar liukaiwen
Browse files

Merge branch 'dev' into dev-table-model-update

parents a0eff3be 6d571e2e
...@@ -6,9 +6,18 @@ ...@@ -6,9 +6,18 @@
"models-dir":"/tmp/models", "models-dir":"/tmp/models",
"layoutreader-model-dir":"/tmp/layoutreader", "layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu", "device-mode":"cpu",
"layout-config": {
"model": "layoutlmv3"
},
"formula-config": {
"mfd_model": "yolo_v8_mfd",
"mfr_model": "unimernet_small",
"enable": true
},
"table-config": { "table-config": {
"model": "TableMaster", "model": "tablemaster",
"is_table_recog_enable": false, "enable": false,
"max_time": 400 "max_time": 400
} },
"config_version": "1.0.0"
} }
\ No newline at end of file
import enum
class SupportedPdfParseMethod(enum.Enum):
OCR = 'ocr'
TXT = 'txt'
class FileNotExisted(Exception):
def __init__(self, path):
self.path = path
def __str__(self):
return f'File {self.path} does not exist.'
class InvalidConfig(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid config: {self.msg}'
class InvalidParams(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid params: {self.msg}'
class EmptyData(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Empty data: {self.msg}'
from magic_pdf.data.data_reader_writer.filebase import \
FileBasedDataReader # noqa: F401
from magic_pdf.data.data_reader_writer.filebase import \
FileBasedDataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
MultiBucketS3DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
MultiBucketS3DataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.s3 import S3DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.s3 import S3DataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.base import DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.base import DataWriter # noqa: F401
\ No newline at end of file
from abc import ABC, abstractmethod
class DataReader(ABC):
def read(self, path: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return self.read_at(path)
@abstractmethod
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file at offset and limit.
Args:
path (str): the file path
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of the file
"""
pass
class DataWriter(ABC):
@abstractmethod
def write(self, path: str, data: bytes) -> None:
"""Write the data to the file.
Args:
path (str): the target file where to write
data (bytes): the data want to write
"""
pass
def write_string(self, path: str, data: str) -> None:
"""Write the data to file, the data will be encoded to bytes.
Args:
path (str): the target file where to write
data (str): the data want to write
"""
self.write(path, data.encode())
import os
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
class FileBasedDataReader(DataReader):
def __init__(self, parent_dir: str = ''):
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
with open(fn_path, 'rb') as f:
f.seek(offset)
if limit == -1:
return f.read()
else:
return f.read(limit)
class FileBasedDataWriter(DataWriter):
def __init__(self, parent_dir: str = '') -> None:
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def write(self, path: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
with open(fn_path, 'wb') as f:
f.write(data)
from magic_pdf.config.exceptions import InvalidConfig, InvalidParams
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
from magic_pdf.data.io.s3 import S3Reader, S3Writer
from magic_pdf.data.schemas import S3Config
from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
remove_non_official_s3_args)
class MultiS3Mixin:
def __init__(self, default_bucket: str, s3_configs: list[S3Config]):
"""Initialized with multiple s3 configs.
Args:
default_bucket (str): the default bucket name of the relative path
s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.
Raises:
InvalidConfig: default bucket config not in s3_configs
InvalidConfig: bucket name not unique in s3_configs
InvalidConfig: default bucket must be provided
"""
if len(default_bucket) == 0:
raise InvalidConfig('default_bucket must be provided')
found_default_bucket_config = False
for conf in s3_configs:
if conf.bucket_name == default_bucket:
found_default_bucket_config = True
break
if not found_default_bucket_config:
raise InvalidConfig(
f'default_bucket: {default_bucket} config must be provided in s3_configs: {s3_configs}'
)
uniq_bucket = set([conf.bucket_name for conf in s3_configs])
if len(uniq_bucket) != len(s3_configs):
raise InvalidConfig(
f'the bucket_name in s3_configs: {s3_configs} must be unique'
)
self.default_bucket = default_bucket
self.s3_configs = s3_configs
self._s3_clients_h: dict = {}
class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
def read(self, path: str) -> bytes:
"""Read the path from s3, select diffect bucket client for each request
based on the path, also support range read.
Args:
path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit
for example: s3://bucket_name/path?0,100
Returns:
bytes: the content of s3 file
"""
may_range_params = parse_s3_range_params(path)
if may_range_params is None or 2 != len(may_range_params):
byte_start, byte_len = 0, -1
else:
byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
path = remove_non_official_s3_args(path)
return self.read_at(path, byte_start, byte_len)
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Reader(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file with offset and limit, select diffect bucket client
for each request based on the path.
Args:
path (str): the file path
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.
Returns:
bytes: the file content
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_reader = self.__get_s3_client(bucket_name)
else:
s3_reader = self.__get_s3_client(self.default_bucket)
return s3_reader.read_at(path, offset, limit)
class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Writer(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def write(self, path: str, data: bytes) -> None:
"""Write file with data, also select diffect bucket client for each
request based on the path.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_writer = self.__get_s3_client(bucket_name)
else:
s3_writer = self.__get_s3_client(self.default_bucket)
return s3_writer.write(path, data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment