Commit d32a63ca authored by myhloli

fix(model): improve VRAM detection and handling

- Refactor VRAM detection logic for better readability and efficiency
- Add fallback mechanism for unknown VRAM sizes
- Improve device checking in get_vram function
parent d4cda0a8
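
The change separates detection from the environment override: get_vram runs first, the VIRTUAL_VRAM_SIZE variable can then override a successful detection, and a None result now takes a safe default instead of crashing inside round(). A minimal standalone sketch of that resolution order (resolve_gpu_memory is a hypothetical helper mirroring the patched code, not a function in this repo):

import os

def resolve_gpu_memory(vram):
    # Mirrors the patched flow: if detection failed (None), skip the
    # override entirely and let the caller fall back to batch_ratio = 1.
    if vram is None:
        return None
    # os.getenv returns the env value as a string when the variable is
    # set, otherwise the rounded detected size (an int); int() normalizes both.
    return int(os.getenv('VIRTUAL_VRAM_SIZE', round(vram)))
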
@@ -255,8 +255,9 @@ def may_batch_image_analyze(
             torch.npu.set_compile_mode(jit_compile=False)
     if str(device).startswith('npu') or str(device).startswith('cuda'):
-        gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
-        if gpu_memory is not None:
+        vram = get_vram(device)
+        if vram is not None:
+            gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(vram)))
             if gpu_memory >= 16:
                 batch_ratio = 16
             elif gpu_memory >= 12:
@@ -268,6 +269,10 @@ def may_batch_image_analyze(
             else:
                 batch_ratio = 1
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
+        else:
+            # Default batch_ratio when VRAM can't be determined
+            batch_ratio = 1
+            logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
     # doc_analyze_start = time.time()
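A quick usage sketch of the new behavior, reusing the hypothetical resolve_gpu_memory helper from above (values are illustrative):

os.environ['VIRTUAL_VRAM_SIZE'] = '24'   # pretend a 24 GB card
assert resolve_gpu_memory(7.9) == 24     # the override wins over detection
del os.environ['VIRTUAL_VRAM_SIZE']
assert resolve_gpu_memory(7.9) == 8      # round(7.9) -> 8
assert resolve_gpu_memory(None) is None  # unknown VRAM -> default batch_ratio = 1
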
@@ -57,7 +57,7 @@ def clean_vram(device, vram_threshold=8):
 def get_vram(device):
-    if torch.cuda.is_available() and device != 'cpu':
+    if torch.cuda.is_available() and str(device).startswith("cuda"):
         total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
         return total_memory
     elif str(device).startswith("npu"):
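
The get_vram hunk fixes a real edge case: under the old device != 'cpu' test, a non-CUDA device string such as 'npu:0' would enter the CUDA branch whenever a CUDA runtime happened to be available, and torch.cuda.get_device_properties would then be handed a device it cannot describe. A standalone sketch of the corrected check (the NPU branch is elided here, as it is in the diff):

import torch

def get_vram_sketch(device):
    # Only query the CUDA API for devices that really are CUDA devices;
    # 'cpu', 'mps', 'npu:0' etc. no longer enter this branch by accident.
    if torch.cuda.is_available() and str(device).startswith('cuda'):
        props = torch.cuda.get_device_properties(device)
        return props.total_memory / (1024 ** 3)  # bytes -> GB
    return None  # unknown device: the caller applies the batch_ratio = 1 default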