Commit d32a63ca authored by myhloli
Browse files

fix(model): improve VRAM detection and handling

- Refactor VRAM detection logic for better readability and efficiency
- Add fallback mechanism for unknown VRAM sizes
- Improve device checking in get_vram function
parent d4cda0a8
...@@ -255,8 +255,9 @@ def may_batch_image_analyze( ...@@ -255,8 +255,9 @@ def may_batch_image_analyze(
torch.npu.set_compile_mode(jit_compile=False) torch.npu.set_compile_mode(jit_compile=False)
if str(device).startswith('npu') or str(device).startswith('cuda'): if str(device).startswith('npu') or str(device).startswith('cuda'):
gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device)))) vram = get_vram(device)
if gpu_memory is not None: if vram is not None:
gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(vram)))
if gpu_memory >= 16: if gpu_memory >= 16:
batch_ratio = 16 batch_ratio = 16
elif gpu_memory >= 12: elif gpu_memory >= 12:
...@@ -268,6 +269,10 @@ def may_batch_image_analyze( ...@@ -268,6 +269,10 @@ def may_batch_image_analyze(
else: else:
batch_ratio = 1 batch_ratio = 1
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}') logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
else:
# Default batch_ratio when VRAM can't be determined
batch_ratio = 1
logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
# doc_analyze_start = time.time() # doc_analyze_start = time.time()
......
...@@ -57,7 +57,7 @@ def clean_vram(device, vram_threshold=8): ...@@ -57,7 +57,7 @@ def clean_vram(device, vram_threshold=8):
def get_vram(device): def get_vram(device):
if torch.cuda.is_available() and device != 'cpu': if torch.cuda.is_available() and str(device).startswith("cuda"):
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
return total_memory return total_memory
elif str(device).startswith("npu"): elif str(device).startswith("npu"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment