Commit e821eebf authored by helloyongyang

Simplify the infer entry point

parent c745e2c4
@@ -418,7 +418,6 @@ def run_inference(
     config.update({k: v for k, v in vars(args).items()})
     config = EasyDict(config)
-    config["mode"] = "infer"
     config.update(model_config)
     config.update(quant_model_config)
...
@@ -420,7 +420,6 @@ def run_inference(
     config.update({k: v for k, v in vars(args).items()})
     config = EasyDict(config)
-    config["mode"] = "infer"
     config.update(model_config)
     config.update(quant_model_config)
...
@@ -121,7 +121,6 @@ if __name__ == "__main__":
     with ProfilingContext("Init Server Cost"):
         config = set_config(args)
-        config["mode"] = "split_server"
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = DiTRunner(config)
...
@@ -116,7 +116,6 @@ if __name__ == "__main__":
     with ProfilingContext("Init Server Cost"):
         config = set_config(args)
-        config["mode"] = "split_server"
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = ImageEncoderRunner(config)
...
@@ -119,7 +119,6 @@ if __name__ == "__main__":
     with ProfilingContext("Init Server Cost"):
         config = set_config(args)
-        config["mode"] = "split_server"
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = TextEncoderRunner(config)
...
@@ -168,7 +168,6 @@ if __name__ == "__main__":
     with ProfilingContext("Init Server Cost"):
         config = set_config(args)
-        config["mode"] = "split_server"
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = VAERunner(config)
...
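The hunks above drop the per-script config["mode"] = "infer" / "split_server" assignments; where the mode value comes from after this commit is not visible in the diff. As a minimal sketch only, assuming the default is centralized in a single shared helper (the name build_config and its behavior are hypothetical, not code from this repository), the idea could look like:

# Minimal sketch; build_config is a hypothetical stand-in for wherever the
# shared default now lives -- it is not code from this commit.
from easydict import EasyDict


def build_config(raw, mode="infer"):
    """Merge raw options into an EasyDict and default the run mode in one place."""
    config = EasyDict(raw)
    if "mode" not in config:
        config["mode"] = mode  # e.g. "infer", "server", or "split_server"
    return config


if __name__ == "__main__":
    cfg = build_config({"model_path": "/path/to/model"}, mode="split_server")
    print(cfg.mode)  # -> "split_server"

Centralizing the default in one place would keep every entry point down to set_config -> log -> runner, which matches the context lines that remain in these hunks.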
@@ -42,8 +42,9 @@ def init_runner(config):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan"
+        "--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="wan2.1"
     )
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
...
@@ -51,36 +52,17 @@ def main():
     parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
     parser.add_argument("--negative_prompt", type=str, default="")
-    parser.add_argument("--lora_path", type=str, default="", help="The lora file path")
-    parser.add_argument("--lora_strength", type=float, default=1.0, help="The strength for the lora (default: 1.0)")
-    parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
-    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
-    parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
+    parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task")
+    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file for audio-to-video (a2v) task")
     parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
     args = parser.parse_args()
-    if args.prompt_path:
-        try:
-            with open(args.prompt_path, "r", encoding="utf-8") as f:
-                args.prompt = f.read().strip()
-            logger.info(f"从文件 {args.prompt_path} 读取到prompt: {args.prompt}")
-        except FileNotFoundError:
-            logger.error(f"找不到prompt文件: {args.prompt_path}")
-            raise
-        except Exception as e:
-            logger.error(f"读取prompt文件时出错: {e}")
-            raise
-    if args.lora_path:
-        args.lora_configs = [{"path": args.lora_path, "strength": args.lora_strength}]
-        delattr(args, "lora_path")
-        delattr(args, "lora_strength")
     logger.info(f"args: {args}")
     with ProfilingContext("Total Cost"):
         config = set_config(args)
-        config["mode"] = "infer"
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = init_runner(config)
...
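This hunk also removes the --prompt_path, --lora_path, and --lora_strength flags together with their handling, so the entry point no longer reads the prompt from a file or builds args.lora_configs itself. As a hypothetical helper only (the function name, and the assumption that downstream code still consumes a prompt string and a lora_configs list, are not confirmed by this diff), the deleted conveniences could be reproduced by a caller before the config is assembled:

# Hypothetical helper, not part of this commit: rebuilds the overrides that the
# deleted --prompt_path/--lora_path/--lora_strength handling used to produce.
from pathlib import Path


def expand_overrides(prompt_path=None, lora_path=None, lora_strength=1.0):
    """Return config overrides equivalent to the removed CLI conveniences."""
    overrides = {}
    if prompt_path:
        # the removed code did: args.prompt = f.read().strip()
        overrides["prompt"] = Path(prompt_path).read_text(encoding="utf-8").strip()
    if lora_path:
        # the removed code built: [{"path": ..., "strength": ...}]
        overrides["lora_configs"] = [{"path": lora_path, "strength": lora_strength}]
    return overrides


if __name__ == "__main__":
    # example path is a placeholder
    print(expand_overrides(lora_path="./example_lora.safetensors", lora_strength=0.8))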
@@ -90,7 +90,6 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, ar
     # Initialize configuration and model
     config = set_config(args)
-    config["mode"] = "server"
     logger.info(f"Rank {rank} config: {config}")
     runner = init_runner(config)
...