Commit 57cd9bc2 authored by gaclove's avatar gaclove
Browse files

Refactor async calls to synchronous in inference scripts

parent 5103aef7
import os import os
import gradio as gr import gradio as gr
import asyncio
import argparse import argparse
import json import json
import torch import torch
...@@ -449,7 +448,7 @@ def run_inference( ...@@ -449,7 +448,7 @@ def run_inference(
else: else:
runner.config = config runner.config = config
asyncio.run(runner.run_pipeline()) runner.run_pipeline()
del config, args, model_config, quant_model_config del config, args, model_config, quant_model_config
if "dit_quantized_ckpt" in locals(): if "dit_quantized_ckpt" in locals():
......
import os import os
import gradio as gr import gradio as gr
import asyncio
import argparse import argparse
import json import json
import torch import torch
...@@ -451,7 +450,7 @@ def run_inference( ...@@ -451,7 +450,7 @@ def run_inference(
else: else:
runner.config = config runner.config = config
asyncio.run(runner.run_pipeline()) runner.run_pipeline()
del config, args, model_config, quant_model_config del config, args, model_config, quant_model_config
if "dit_quantized_ckpt" in locals(): if "dit_quantized_ckpt" in locals():
......
import asyncio
import argparse import argparse
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -40,7 +39,7 @@ def init_runner(config): ...@@ -40,7 +39,7 @@ def init_runner(config):
return runner return runner
async def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan" "--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan"
...@@ -85,8 +84,8 @@ async def main(): ...@@ -85,8 +84,8 @@ async def main():
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}") logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config) runner = init_runner(config)
await runner.run_pipeline() runner.run_pipeline()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment