Commit 57cd9bc2 authored by gaclove's avatar gaclove
Browse files

Refactor async calls to synchronous in inference scripts

parent 5103aef7
import os
import gradio as gr
import asyncio
import argparse
import json
import torch
......@@ -449,7 +448,7 @@ def run_inference(
else:
runner.config = config
asyncio.run(runner.run_pipeline())
runner.run_pipeline()
del config, args, model_config, quant_model_config
if "dit_quantized_ckpt" in locals():
......
import os
import gradio as gr
import asyncio
import argparse
import json
import torch
......@@ -451,7 +450,7 @@ def run_inference(
else:
runner.config = config
asyncio.run(runner.run_pipeline())
runner.run_pipeline()
del config, args, model_config, quant_model_config
if "dit_quantized_ckpt" in locals():
......
import asyncio
import argparse
import torch
import torch.distributed as dist
......@@ -40,7 +39,7 @@ def init_runner(config):
return runner
async def main():
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan"
......@@ -85,8 +84,8 @@ async def main():
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config)
await runner.run_pipeline()
runner.run_pipeline()
if __name__ == "__main__":
asyncio.run(main())
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment