"tests/vscode:/vscode.git/clone" did not exist on "e748b3c6e163ce9a61965eb456704a83b855ccc3"
Commit 84ece5f5 authored by helloyongyang's avatar helloyongyang
Browse files

Support server

parent 4e550f37
import argparse
import asyncio
import json

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

from lightx2v.infer import init_runner
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
class Message(BaseModel):
    """Request payload for the /v1/local/video/generate endpoint.

    Mirrors the CLI arguments of the offline inference entry point:
    only ``prompt`` and ``save_video_path`` are required; the other
    fields default to empty strings.
    """

    prompt: str
    # Optional negative prompt; empty string means "none".
    negative_prompt: str = ""
    # Input image for i2v tasks; empty string for t2v.
    image_path: str = ""
    save_video_path: str

    def get(self, key, default=None):
        """Dict-style accessor so the model can be consumed like a mapping.

        ``runner.set_inputs`` reads fields via ``inputs.get(...)``, so this
        shim lets a ``Message`` stand in for a plain dict.
        """
        return getattr(self, key, default)
async def main(message):
    """Run one video-generation request through the shared global runner.

    Args:
        message: Mapping-like payload exposing ``get`` (the ``Message``
            model or a plain dict) with prompt / negative_prompt /
            image_path / save_video_path fields.

    Returns:
        dict: ``{"response": "finished"}`` once the pipeline completes.
    """
    # NOTE(review): the module-level ``runner`` is shared mutable state; two
    # concurrent requests would interleave set_inputs/run_pipeline — confirm
    # the deployment serializes requests or add a lock.
    runner.set_inputs(message)
    # run_pipeline() is synchronous and heavy; off-load it to a worker
    # thread so the asyncio event loop stays responsive while rendering.
    await asyncio.to_thread(runner.run_pipeline)
    return {"response": "finished"}
if __name__ == "__main__":
    # CLI mirrors the offline inference entry point, minus per-request
    # fields (prompt/paths), which arrive via the HTTP API instead.
    parser = argparse.ArgumentParser()
    # NOTE: the original declared both required=True and default="hunyuan";
    # argparse ignores the default when the flag is required, so the dead,
    # misleading default has been removed.
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"])
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    print(f"args: {args}")

    # Load config and weights once at startup; every request reuses `runner`.
    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = init_runner(config)

    app = FastAPI()

    @app.post("/v1/local/video/generate")
    async def generate_video(message: Message):
        """HTTP wrapper: validate the payload and delegate to main()."""
        response = await main(message)
        return response

    # NOTE(review): binds on all interfaces (0.0.0.0) — intended for trusted
    # networks only; there is no authentication on the endpoint.
    uvicorn.run(app, host="0.0.0.0", port=config.port)
...@@ -17,15 +17,30 @@ from lightx2v.models.runners.graph_runner import GraphRunner ...@@ -17,15 +17,30 @@ from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.common.ops import * from lightx2v.common.ops import *
def init_runner(config):
    """Build the runner for ``config.model_cls``, seeding RNGs first.

    Initializes the NCCL process group when a parallel attention type is
    configured, and wraps the base runner in a ``GraphRunner`` when graph
    mode is enabled.
    """
    seed_all(config.seed)

    # Distributed setup is only needed for parallel attention.
    if config.parallel_attn_type:
        dist.init_process_group(backend="nccl")

    runner = RUNNER_REGISTER[config.model_cls](config)
    if CHECK_ENABLE_GRAPH_MODE():
        runner = GraphRunner(runner)
    return runner
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"], default="hunyuan") parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v") parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True) parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--prompt", type=str, required=True) parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--negative_prompt", type=str, default="") parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--image_path", type=str, default=None, help="The path to input image file or path for image-to-video (i2v) task") parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file") parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
args = parser.parse_args() args = parser.parse_args()
print(f"args: {args}") print(f"args: {args}")
...@@ -33,15 +48,6 @@ if __name__ == "__main__": ...@@ -33,15 +48,6 @@ if __name__ == "__main__":
with ProfilingContext("Total Cost"): with ProfilingContext("Total Cost"):
config = set_config(args) config = set_config(args)
print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}") print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config)
seed_all(config.seed)
if config.parallel_attn_type:
dist.init_process_group(backend="nccl")
if CHECK_ENABLE_GRAPH_MODE():
default_runner = RUNNER_REGISTER[config.model_cls](config)
runner = GraphRunner(default_runner)
else:
runner = RUNNER_REGISTER[config.model_cls](config)
runner.run_pipeline() runner.run_pipeline()
...@@ -11,6 +11,12 @@ class DefaultRunner: ...@@ -11,6 +11,12 @@ class DefaultRunner:
self.config = config self.config = config
self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model() self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model()
def set_inputs(self, inputs):
    """Copy per-request fields from *inputs* into this runner's config.

    Args:
        inputs: Mapping-like object exposing ``get`` (a dict or the API's
            ``Message`` model). Missing fields default to "".
    """
    for field in ("prompt", "negative_prompt", "image_path", "save_video_path"):
        self.config[field] = inputs.get(field, "")
def run_input_encoder(self): def run_input_encoder(self):
image_encoder_output = None image_encoder_output = None
if self.config["task"] == "i2v": if self.config["task"] == "i2v":
......
...@@ -28,12 +28,12 @@ def set_config(args): ...@@ -28,12 +28,12 @@ def set_config(args):
config.update({k: v for k, v in vars(args).items()}) config.update({k: v for k, v in vars(args).items()})
config = EasyDict(config) config = EasyDict(config)
with open(args.config_json, "r") as f: with open(config.config_json, "r") as f:
config_json = json.load(f) config_json = json.load(f)
config.update(config_json) config.update(config_json)
if os.path.exists(os.path.join(args.model_path, "config.json")): if os.path.exists(os.path.join(config.model_path, "config.json")):
with open(os.path.join(args.model_path, "config.json"), "r") as f: with open(os.path.join(config.model_path, "config.json"), "r") as f:
model_config = json.load(f) model_config = json.load(f)
config.update(model_config) config.update(model_config)
......
# Example client: submit one generation request to a locally running
# lightx2v API server and print its JSON reply.
import requests

url = "http://localhost:8000/v1/local/video/generate"
# Payload schema matches the server's `Message` model: `prompt` and
# `save_video_path` are required; the other fields may be empty.
message = {
    "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    "image_path": "",
    "save_video_path": "./output_lightx2v_wan_t2v_ap4.mp4",  # It is best to set it to an absolute path.
}
# The server blocks until the whole video is rendered, so this call can
# take a long time; the response body is {"response": "finished"}.
response = requests.post(url, json=message)
print(f"response: {response.json()}")
...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
python -m lightx2v \ python -m lightx2v.infer \
--model_cls hunyuan \ --model_cls hunyuan \
--task i2v \ --task i2v \
--model_path $model_path \ --model_path $model_path \
......
...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
python -m lightx2v \ python -m lightx2v.infer \
--model_cls hunyuan \ --model_cls hunyuan \
--task t2v \ --task t2v \
--model_path $model_path \ --model_path $model_path \
......
...@@ -27,7 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -27,7 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/infer.py \
--model_cls hunyuan \ --model_cls hunyuan \
--task t2v \ --task t2v \
--model_path $model_path \ --model_path $model_path \
...@@ -35,7 +35,7 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \ ...@@ -35,7 +35,7 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--prompt "A cat walks on the grass, realistic style." \ --prompt "A cat walks on the grass, realistic style." \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hunyuan_t2v_dist_ulysses.mp4 --save_video_path ${lightx2v_path}/save_results/output_lightx2v_hunyuan_t2v_dist_ulysses.mp4
torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/infer.py \
--model_cls hunyuan \ --model_cls hunyuan \
--task t2v \ --task t2v \
--model_path $model_path \ --model_path $model_path \
......
...@@ -27,7 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -27,7 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
python -m lightx2v \ python -m lightx2v.infer \
--model_cls hunyuan \ --model_cls hunyuan \
--task t2v \ --task t2v \
--model_path $model_path \ --model_path $model_path \
......
...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
python -m lightx2v \ python -m lightx2v.infer \
--model_cls wan2.1 \ --model_cls wan2.1 \
--task i2v \ --task i2v \
--model_path $model_path \ --model_path $model_path \
......
...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
python -m lightx2v \ python -m lightx2v.infer \
--model_cls wan2.1 \ --model_cls wan2.1 \
--task i2v \ --task i2v \
--model_path $model_path \ --model_path $model_path \
......
...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
python -m lightx2v \ python -m lightx2v.infer \
--model_cls wan2.1 \ --model_cls wan2.1 \
--task t2v \ --task t2v \
--model_path $model_path \ --model_path $model_path \
......
...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -28,7 +28,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
python -m lightx2v \ python -m lightx2v.infer \
--model_cls wan2.1_causal \ --model_cls wan2.1_causal \
--task t2v \ --task t2v \
--model_path $model_path \ --model_path $model_path \
......
...@@ -27,7 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -27,7 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/infer.py \
--model_cls wan2.1 \ --model_cls wan2.1 \
--task t2v \ --task t2v \
--model_path $model_path \ --model_path $model_path \
...@@ -45,7 +45,7 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \ ...@@ -45,7 +45,7 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--parallel_vae \ --parallel_vae \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_dist_ring.mp4 --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_dist_ring.mp4
torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/infer.py \
--model_cls wan2.1 \ --model_cls wan2.1 \
--task t2v \ --task t2v \
--model_path $model_path \ --model_path $model_path \
......
#!/bin/bash
# Launch the lightx2v API server (wan2.1, t2v) on port 8000.
# Set lightx2v_path and model_path below before running.
lightx2v_path=
model_path=

# ---- sanity checks -------------------------------------------------------
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    # Fixed typo: "defalt" -> "default".
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

# ---- environment ---------------------------------------------------------
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

# Paths are quoted so values containing spaces survive word-splitting.
python -m lightx2v.api_server \
    --model_cls wan2.1 \
    --task t2v \
    --model_path "${model_path}" \
    --config_json "${lightx2v_path}/configs/wan_t2v.json" \
    --port 8000
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment