Commit 9b00f988 authored by myhloli
Browse files

refactor(magic_pdf): remove bfloat16 support checks and usage

- Remove supports_bfloat16 variable and related checks
- Remove model.bfloat16() call for LayoutLMv3ForTokenClassification
- Simplify device selection logic
parent 315adbce
...@@ -341,21 +341,14 @@ def model_init(model_name: str): ...@@ -341,21 +341,14 @@ def model_init(model_name: str):
device = get_device() device = get_device()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device('cuda') device = torch.device('cuda')
if torch.cuda.is_bf16_supported():
supports_bfloat16 = True
else:
supports_bfloat16 = False
elif str(device).startswith("npu"): elif str(device).startswith("npu"):
import torch_npu import torch_npu
if torch_npu.npu.is_available(): if torch_npu.npu.is_available():
device = torch.device('npu') device = torch.device('npu')
supports_bfloat16 = False
else: else:
device = torch.device('cpu') device = torch.device('cpu')
supports_bfloat16 = False
else: else:
device = torch.device('cpu') device = torch.device('cpu')
supports_bfloat16 = False
if model_name == 'layoutreader': if model_name == 'layoutreader':
# 检测modelscope的缓存目录是否存在 # 检测modelscope的缓存目录是否存在
...@@ -371,9 +364,6 @@ def model_init(model_name: str): ...@@ -371,9 +364,6 @@ def model_init(model_name: str):
model = LayoutLMv3ForTokenClassification.from_pretrained( model = LayoutLMv3ForTokenClassification.from_pretrained(
'hantian/layoutreader' 'hantian/layoutreader'
) )
# 检查设备是否支持 bfloat16
if supports_bfloat16:
model.bfloat16()
model.to(device).eval() model.to(device).eval()
else: else:
logger.error('model name not allow') logger.error('model name not allow')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment