".github/vscode:/vscode.git/clone" did not exist on "1070e1a38a41637400361a694966350fe790d5a4"
Unverified Commit 07eaa2d7 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1959 from myhloli/dev

Dev push
parents 132c16ad 2f40fa7d
@@ -257,13 +257,13 @@ def may_batch_image_analyze(
     if str(device).startswith('npu') or str(device).startswith('cuda'):
         gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
         if gpu_memory is not None:
-            if gpu_memory >= 20:
+            if gpu_memory >= 16:
                 batch_ratio = 16
-            elif gpu_memory >= 15:
+            elif gpu_memory >= 12:
                 batch_ratio = 8
-            elif gpu_memory >= 10:
+            elif gpu_memory >= 8:
                 batch_ratio = 4
-            elif gpu_memory >= 7:
+            elif gpu_memory >= 6:
                 batch_ratio = 2
             else:
                 batch_ratio = 1
...
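The hunk above lowers the VRAM thresholds used to pick the batch multiplier, so each tier now activates on smaller GPUs. A minimal standalone sketch of the new mapping, assuming the same environment-variable override as the patched code; the helper name choose_batch_ratio is illustrative and not part of the source:

import os

def choose_batch_ratio(gpu_memory_gb: int) -> int:
    # Map available VRAM (GiB) to a batch multiplier, mirroring the new thresholds.
    if gpu_memory_gb >= 16:
        return 16
    elif gpu_memory_gb >= 12:
        return 8
    elif gpu_memory_gb >= 8:
        return 4
    elif gpu_memory_gb >= 6:
        return 2
    return 1

# VIRTUAL_VRAM_SIZE overrides the detected VRAM, as in the patched code.
gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', 8))
print(choose_batch_ratio(gpu_memory))  # -> 4 for an 8 GiB card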
@@ -333,8 +333,14 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
 def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
-    device = torch.device(get_device())
+    device_name = get_device()
+    bf_16_support = False
+    if device_name.startswith("cuda"):
+        bf_16_support = torch.cuda.is_bf16_supported()
+    elif device_name.startswith("mps"):
+        bf_16_support = True
+    device = torch.device(device_name)
     if model_name == 'layoutreader':
         # Check whether the modelscope cache directory exists
         layoutreader_model_dir = get_local_layoutreader_model_dir()
@@ -349,7 +355,10 @@ def model_init(model_name: str):
             model = LayoutLMv3ForTokenClassification.from_pretrained(
                 'hantian/layoutreader'
            )
-        model.to(device).eval()
+        if bf_16_support:
+            model.to(device).eval().bfloat16()
+        else:
+            model.to(device).eval()
     else:
         logger.error('model name not allow')
         exit(1)
...
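The second change gates the bfloat16 cast on device support: torch.cuda.is_bf16_supported() for CUDA devices, and an unconditional True for MPS. A hedged sketch of the same pattern as a standalone loader; the function load_layoutreader is illustrative, not part of the patch:

import torch
from transformers import LayoutLMv3ForTokenClassification

def load_layoutreader(device_name: str):
    # Illustrative loader: cast to bfloat16 only when the target device supports it.
    bf_16_support = False
    if device_name.startswith("cuda"):
        # Checks the GPU's compute capability / PyTorch build for bf16 support.
        bf_16_support = torch.cuda.is_bf16_supported()
    elif device_name.startswith("mps"):
        bf_16_support = True

    model = LayoutLMv3ForTokenClassification.from_pretrained('hantian/layoutreader')
    model = model.to(torch.device(device_name)).eval()
    if bf_16_support:
        model = model.bfloat16()
    return model

Casting to bfloat16 roughly halves the model's memory footprint and speeds up inference on hardware with native bf16 support, while the fallback keeps float32 behavior unchanged on older GPUs.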