Commit b1fe9d4f authored by myhloli
Browse files

feat(gradio_app): implement dynamic concurrency limit based on VRAM

- Add get_concurrency_limit function to calculate concurrency limit based on VRAM
- Extract the VRAM query from clean_vram into a new get_vram helper for better clarity
- Apply concurrency limit to the to_markdown function in the Gradio app
parent fdf47155
...@@ -42,10 +42,16 @@ def get_res_list_from_layout_res(layout_res): ...@@ -42,10 +42,16 @@ def get_res_list_from_layout_res(layout_res):
def clean_vram(device, vram_threshold=8): def clean_vram(device, vram_threshold=8):
total_memory = get_vram(device)
if total_memory <= vram_threshold:
gc_start = time.time()
clean_memory()
gc_time = round(time.time() - gc_start, 2)
logger.info(f"gc time: {gc_time}")
def get_vram(device):
if torch.cuda.is_available() and device != 'cpu': if torch.cuda.is_available() and device != 'cpu':
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
if total_memory <= vram_threshold: return total_memory
gc_start = time.time() return 0
clean_memory() \ No newline at end of file
gc_time = round(time.time() - gc_start, 2)
logger.info(f"gc time: {gc_time}")
\ No newline at end of file
...@@ -14,7 +14,9 @@ from gradio_pdf import PDF ...@@ -14,7 +14,9 @@ from gradio_pdf import PDF
from loguru import logger from loguru import logger
from magic_pdf.data.data_reader_writer import FileBasedDataReader from magic_pdf.data.data_reader_writer import FileBasedDataReader
from magic_pdf.libs.config_reader import get_device
from magic_pdf.libs.hash_utils import compute_sha256 from magic_pdf.libs.hash_utils import compute_sha256
from magic_pdf.model.sub_modules.model_utils import get_vram
from magic_pdf.tools.common import do_parse, prepare_env from magic_pdf.tools.common import do_parse, prepare_env
...@@ -183,6 +185,15 @@ def to_pdf(file_path): ...@@ -183,6 +185,15 @@ def to_pdf(file_path):
return tmp_file_path return tmp_file_path
def get_concurrency_limit(vram_threshold=7.5):
    """Return how many conversions may run concurrently on this device.

    The limit is the number of whole ``vram_threshold``-sized slices that
    fit into the value reported by ``get_vram`` for the configured device,
    and is never lower than 1.

    Args:
        vram_threshold: Memory budget per concurrent job, in the same unit
            that ``get_vram`` returns.

    Returns:
        int: The concurrency limit, always >= 1.
    """
    available = get_vram(device = get_device())
    # Floor-divide to count whole slices, then clamp so at least one job
    # is always allowed even when the reported memory is below the budget.
    return max(int(available // vram_threshold), 1)
if __name__ == '__main__': if __name__ == '__main__':
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.HTML(header) gr.HTML(header)
...@@ -219,7 +230,7 @@ if __name__ == '__main__': ...@@ -219,7 +230,7 @@ if __name__ == '__main__':
md_text = gr.TextArea(lines=45, show_copy_button=True) md_text = gr.TextArea(lines=45, show_copy_button=True)
file.upload(fn=to_pdf, inputs=file, outputs=pdf_show) file.upload(fn=to_pdf, inputs=file, outputs=pdf_show)
change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language], change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
outputs=[md, md_text, output_file, pdf_show]) outputs=[md, md_text, output_file, pdf_show], concurrency_limit=get_concurrency_limit())
clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language]) clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language])
demo.launch(server_name='0.0.0.0') demo.launch(server_name='0.0.0.0')
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment