Commit 5e08fd1a authored by PengGao, committed by GitHub

Enhance BaseRunner with model protocols and type hints and fix lora load (#117)

* Enhance BaseRunner with model protocols and type hints and fix lora load

* Add support for multiple LoRAs configuration and enhance lora_strength argument handling

* doc: style fix
parent 7a35c418
......@@ -39,9 +39,29 @@ python -m lightx2v.infer \
--model_path /path/to/model \
--config_json /path/to/config.json \
--lora_path /path/to/your/lora.safetensors \
--lora_strength 0.8 \
--prompt "Your prompt here"
```
### Multiple LoRA Configuration
To use multiple LoRAs with different strengths, specify them in the config JSON file:
```json
{
"lora_configs": [
{
"path": "/path/to/first_lora.safetensors",
"strength": 0.8
},
{
"path": "/path/to/second_lora.safetensors",
"strength": 0.5
}
]
}
```
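Each entry's `strength` is optional; the runner changes in this commit fall back to `1.0` via `lora_config.get("strength", 1.0)`. A minimal sketch of reading such a config (inline JSON stands in for the real file, paths are placeholders):
```python
import json

# Stand-in for the contents passed via --config_json.
config_text = """
{
  "lora_configs": [
    {"path": "/path/to/first_lora.safetensors", "strength": 0.8},
    {"path": "/path/to/second_lora.safetensors"}
  ]
}
"""

config = json.loads(config_text)
for entry in config.get("lora_configs") or []:
    # strength is optional and defaults to 1.0, matching the runner code below
    print(entry["path"], entry.get("strength", 1.0))
```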
### Supported LoRA Formats
LightX2V supports multiple LoRA weight naming conventions:
......
......@@ -38,9 +38,29 @@ python -m lightx2v.infer \
--model_path /path/to/model \
--config_json /path/to/config.json \
--lora_path /path/to/your/lora.safetensors \
--lora_strength 0.8 \
--prompt "Your prompt here"
```
### Multiple LoRA Configuration
To use multiple LoRAs with different strengths, specify them in the config JSON file:
```json
{
"lora_configs": [
{
"path": "/path/to/first_lora.safetensors",
"strength": 0.8
},
{
"path": "/path/to/second_lora.safetensors",
"strength": 0.5
}
]
}
```
### Supported LoRA Formats
LightX2V supports multiple LoRA weight naming conventions:
......
......@@ -49,6 +49,7 @@ def main():
parser.add_argument("--split", action="store_true")
parser.add_argument("--lora_path", type=str, required=False, default=None)
parser.add_argument("--lora_strength", type=float, default=1.0, help="The strength for the lora (default: 1.0)")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node for distributed inference")
......
......@@ -53,6 +53,7 @@ async def main():
parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--lora_path", type=str, default="", help="The lora file path")
parser.add_argument("--lora_strength", type=float, default=1.0, help="The strength for the lora (default: 1.0)")
parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
......@@ -71,6 +72,11 @@ async def main():
logger.error(f"读取prompt文件时出错: {e}")
raise
if args.lora_path:
args.lora_configs = [{"path": args.lora_path, "strength": args.lora_strength}]
delattr(args, "lora_path")
delattr(args, "lora_strength")
logger.info(f"args: {args}")
with ProfilingContext("Total Cost"):
......
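A self-contained illustration of the shim above: a lone `--lora_path`/`--lora_strength` pair is folded into the new `lora_configs` list before the args reach the runner (`SimpleNamespace` stands in for the argparse result; values are placeholders):
```python
from types import SimpleNamespace

# Simulated argparse output.
args = SimpleNamespace(lora_path="/path/to/lora.safetensors", lora_strength=0.8)

if args.lora_path:
    args.lora_configs = [{"path": args.lora_path, "strength": args.lora_strength}]
    delattr(args, "lora_path")
    delattr(args, "lora_strength")

assert args.lora_configs == [{"path": "/path/to/lora.safetensors", "strength": 0.8}]
```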
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple, Optional
from typing import Any, Dict, Tuple, Optional, Union, List, Protocol
from lightx2v.utils.utils import save_videos_grid
class TransformerModel(Protocol):
"""Protocol for transformer models"""
def set_scheduler(self, scheduler: Any) -> None: ...
def scheduler(self) -> Any: ...
class TextEncoderModel(Protocol):
"""Protocol for text encoder models"""
def infer(self, texts: List[str], config: Dict[str, Any]) -> Any: ...
class ImageEncoderModel(Protocol):
"""Protocol for image encoder models"""
def encode(self, image: Any) -> Any: ...
class VAEModel(Protocol):
"""Protocol for VAE models"""
def encode(self, image: Any) -> Tuple[Any, Dict[str, Any]]: ...
def decode(self, latents: Any, generator: Optional[Any] = None, config: Optional[Dict[str, Any]] = None) -> Any: ...
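These are `typing.Protocol` classes, so conformance is structural: any object whose method signatures match satisfies the type, with no inheritance required. A minimal sketch under that reading (`DummyTextEncoder` is hypothetical):
```python
from typing import Any, Dict, List, Protocol

class TextEncoderModel(Protocol):
    def infer(self, texts: List[str], config: Dict[str, Any]) -> Any: ...

class DummyTextEncoder:  # note: does not inherit from the protocol
    def infer(self, texts: List[str], config: Dict[str, Any]) -> Any:
        return [f"embedding({t})" for t in texts]

def encode(encoder: TextEncoderModel, prompt: str) -> Any:
    # A static type checker accepts DummyTextEncoder here purely by shape.
    return encoder.infer([prompt], {})

print(encode(DummyTextEncoder(), "a cat on a skateboard"))
```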
class BaseRunner(ABC):
"""Abstract base class for all Runners
......@@ -13,34 +39,34 @@ class BaseRunner(ABC):
self.config = config
@abstractmethod
def load_transformer(self):
def load_transformer(self) -> TransformerModel:
"""Load transformer model
Returns:
Loaded model instance
Loaded transformer model instance
"""
pass
@abstractmethod
def load_text_encoder(self):
def load_text_encoder(self) -> Union[TextEncoderModel, List[TextEncoderModel]]:
"""Load text encoder
Returns:
Text encoder instance or list of instances
Text encoder instance or list of text encoder instances
"""
pass
@abstractmethod
def load_image_encoder(self):
def load_image_encoder(self) -> Optional[ImageEncoderModel]:
"""Load image encoder
Returns:
Image encoder instance
Image encoder instance or None if not needed
"""
pass
@abstractmethod
def load_vae(self) -> Tuple[Any, Any]:
def load_vae(self) -> Tuple[VAEModel, VAEModel]:
"""Load VAE encoder and decoder
Returns:
......@@ -49,7 +75,7 @@ class BaseRunner(ABC):
pass
@abstractmethod
def run_image_encoder(self, img):
def run_image_encoder(self, img: Any) -> Any:
"""Run image encoder
Args:
......@@ -61,19 +87,19 @@ class BaseRunner(ABC):
pass
@abstractmethod
def run_vae_encoder(self, img):
def run_vae_encoder(self, img: Any) -> Tuple[Any, Dict[str, Any]]:
"""Run VAE encoder
Args:
img: Input image
Returns:
VAE encoding result and additional parameters
Tuple of VAE encoding result and additional parameters
"""
pass
@abstractmethod
def run_text_encoder(self, prompt: str, img: Optional[Any] = None):
def run_text_encoder(self, prompt: str, img: Optional[Any] = None) -> Any:
"""Run text encoder
Args:
......@@ -86,7 +112,7 @@ class BaseRunner(ABC):
pass
@abstractmethod
def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
def get_encoder_output_i2v(self, clip_encoder_out: Any, vae_encode_out: Any, text_encoder_output: Any, img: Any) -> Dict[str, Any]:
"""Combine encoder outputs for i2v task
Args:
......@@ -101,7 +127,7 @@ class BaseRunner(ABC):
pass
@abstractmethod
def init_scheduler(self):
def init_scheduler(self) -> None:
"""Initialize scheduler"""
pass
......@@ -115,7 +141,7 @@ class BaseRunner(ABC):
"""
return {}
def save_video_func(self, images):
def save_video_func(self, images: Any) -> None:
"""Save video implementation
Subclasses can override this method to customize save logic
......@@ -125,7 +151,7 @@ class BaseRunner(ABC):
"""
save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))
def load_vae_decoder(self):
def load_vae_decoder(self) -> VAEModel:
"""Load VAE decoder
Default implementation: get decoder from load_vae method
......
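The default body is collapsed in this diff; a plausible sketch of the behavior the docstring describes, using a hypothetical subclass with stubbed models:
```python
from typing import Any, Tuple

class ExampleRunner:
    """Hypothetical subclass illustrating the described default."""

    def load_vae(self) -> Tuple[Any, Any]:
        return ("encoder-stub", "decoder-stub")  # placeholder (encoder, decoder)

    def load_vae_decoder(self) -> Any:
        # Default behavior per the docstring: reuse load_vae, keep the decoder.
        _, vae_decoder = self.load_vae()
        return vae_decoder

print(ExampleRunner().load_vae_decoder())  # -> decoder-stub
```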
......@@ -163,7 +163,10 @@ class DefaultRunner(BaseRunner):
text_encoder_output = self.run_text_encoder(prompt, None)
torch.cuda.empty_cache()
gc.collect()
return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
return {
"text_encoder_output": text_encoder_output,
"image_encoder_output": None,
}
@ProfilingContext("Run DiT")
def _run_dit_local(self, kwargs):
......@@ -242,7 +245,13 @@ class DefaultRunner(BaseRunner):
for url in self.config["sub_servers"]["prompt_enhancer"]:
response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
if response["service_status"] == "idle":
response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
response = requests.post(
f"{url}/v1/local/prompt_enhancer/generate",
json={
"task_id": generate_task_id(),
"prompt": self.config["prompt"],
},
)
enhanced_prompt = response.json()["output"]
logger.info(f"Enhanced prompt: {enhanced_prompt}")
return enhanced_prompt
......@@ -251,17 +260,36 @@ class DefaultRunner(BaseRunner):
tasks = []
img_byte = self.image_transporter.prepare_image(img)
tasks.append(
asyncio.create_task(self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
asyncio.create_task(
self.post_task(
task_type="image_encoder",
urls=self.config["sub_servers"]["image_encoder"],
message={"task_id": generate_task_id(), "img": img_byte},
device="cuda",
)
)
)
tasks.append(
asyncio.create_task(self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
asyncio.create_task(
self.post_task(
task_type="vae_model/encoder",
urls=self.config["sub_servers"]["vae_model"],
message={"task_id": generate_task_id(), "img": img_byte},
device="cuda",
)
)
)
tasks.append(
asyncio.create_task(
self.post_task(
task_type="text_encoders",
urls=self.config["sub_servers"]["text_encoders"],
message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
message={
"task_id": generate_task_id(),
"text": prompt,
"img": img_byte,
"n_prompt": n_prompt,
},
device="cuda",
)
)
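The hunk collapses before the three encoder tasks are awaited; presumably they are gathered concurrently, with results in append order (clip, vae, text). A generic, self-contained sketch of that pattern (the stub `fetch` stands in for `post_task`; not the repo's collapsed code):
```python
import asyncio

async def fetch(name: str, delay: float) -> str:
    await asyncio.sleep(delay)  # stands in for an HTTP call to a sub-server
    return f"{name}-output"

async def main() -> None:
    tasks = [
        asyncio.create_task(fetch("image_encoder", 0.1)),
        asyncio.create_task(fetch("vae_model/encoder", 0.2)),
        asyncio.create_task(fetch("text_encoders", 0.3)),
    ]
    # gather preserves task order, matching how post_encoders_i2v appends them.
    clip_out, vae_out, text_out = await asyncio.gather(*tasks)
    print(clip_out, vae_out, text_out)

asyncio.run(main())
```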
......@@ -277,7 +305,12 @@ class DefaultRunner(BaseRunner):
self.post_task(
task_type="text_encoders",
urls=self.config["sub_servers"]["text_encoders"],
message={"task_id": generate_task_id(), "text": prompt, "img": None, "n_prompt": n_prompt},
message={
"task_id": generate_task_id(),
"text": prompt,
"img": None,
"n_prompt": n_prompt,
},
device="cuda",
)
)
......@@ -290,7 +323,11 @@ class DefaultRunner(BaseRunner):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
n_prompt = self.config.get("negative_prompt", "")
img = Image.open(self.config["image_path"]).convert("RGB")
clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders_i2v(prompt, img, n_prompt)
(
clip_encoder_out,
vae_encode_out,
text_encoder_output,
) = await self.post_encoders_i2v(prompt, img, n_prompt)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
......@@ -301,7 +338,10 @@ class DefaultRunner(BaseRunner):
text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
torch.cuda.empty_cache()
gc.collect()
return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
return {
"text_encoder_output": text_encoder_output,
"image_encoder_output": None,
}
async def _run_dit_server(self, kwargs):
if self.inputs.get("image_encoder_output", None) is not None:
......@@ -309,7 +349,11 @@ class DefaultRunner(BaseRunner):
dit_output = await self.post_task(
task_type="dit",
urls=self.config["sub_servers"]["dit"],
message={"task_id": generate_task_id(), "inputs": self.tensor_transporter.prepare_tensor(self.inputs), "kwargs": self.tensor_transporter.prepare_tensor(kwargs)},
message={
"task_id": generate_task_id(),
"inputs": self.tensor_transporter.prepare_tensor(self.inputs),
"kwargs": self.tensor_transporter.prepare_tensor(kwargs),
},
device="cuda",
)
return dit_output, None
......@@ -318,7 +362,10 @@ class DefaultRunner(BaseRunner):
images = await self.post_task(
task_type="vae_model/decoder",
urls=self.config["sub_servers"]["vae_model"],
message={"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)},
message={
"task_id": generate_task_id(),
"latents": self.tensor_transporter.prepare_tensor(latents),
},
device="cpu",
)
return images
......
......@@ -278,12 +278,15 @@ class WanAudioRunner(WanRunner):
def load_transformer(self):
base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
if self.config.lora_path:
if self.config.get("lora_configs") and self.config.lora_configs:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
lora_wrapper = WanLoraWrapper(base_model)
lora_name = lora_wrapper.load_lora(self.config.lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model)
logger.info(f"Loaded LoRA: {lora_name}")
for lora_config in self.config.lora_configs:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
return base_model
......
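This load-and-apply loop recurs almost verbatim in WanCausVidRunner, WanDistillRunner, and WanRunner below; a hedged sketch of a helper that could factor it out (`apply_lora_configs` is hypothetical, and the loguru import is an assumption about the project's logger):
```python
from typing import Any, Dict, List

from loguru import logger  # assumption: the project logs via loguru

def apply_lora_configs(lora_wrapper: Any, lora_configs: List[Dict[str, Any]]) -> None:
    """Load and apply each configured LoRA in order."""
    for lora_config in lora_configs:
        strength = lora_config.get("strength", 1.0)  # optional, defaults to 1.0
        lora_name = lora_wrapper.load_lora(lora_config["path"])
        lora_wrapper.apply_lora(lora_name, strength)
        logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
```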
......@@ -31,17 +31,19 @@ class WanCausVidRunner(WanRunner):
self.num_fragments = self.config.num_fragments
def load_transformer(self):
if self.config.lora_path:
if self.config.get("lora_configs") and self.config.lora_configs:
model = WanModel(
self.config.model_path,
self.config,
self.init_device,
)
lora_wrapper = WanLoraWrapper(model)
for lora_path in self.config.lora_path:
for lora_config in self.config.lora_configs:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model)
logger.info(f"Loaded LoRA: {lora_name}")
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
else:
model = WanCausVidModel(self.config.model_path, self.config, self.init_device)
return model
......
......@@ -24,17 +24,19 @@ class WanDistillRunner(WanRunner):
super().__init__(config)
def load_transformer(self):
if self.config.lora_path:
if self.config.get("lora_configs") and self.config.lora_configs:
model = WanModel(
self.config.model_path,
self.config,
self.init_device,
)
lora_wrapper = WanLoraWrapper(model)
for lora_path in self.config.lora_path:
for lora_config in self.config.lora_configs:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model)
logger.info(f"Loaded LoRA: {lora_name}")
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
else:
model = WanDistillModel(self.config.model_path, self.config, self.init_device)
return model
......
......@@ -36,13 +36,15 @@ class WanRunner(DefaultRunner):
self.config,
self.init_device,
)
if self.config.lora_path:
if self.config.get("lora_configs") and self.config.lora_configs:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
lora_wrapper = WanLoraWrapper(model)
for lora_path in self.config.lora_path:
for lora_config in self.config.lora_configs:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model)
logger.info(f"Loaded LoRA: {lora_name}")
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
return model
def load_image_encoder(self):
......
......@@ -186,6 +186,12 @@ class DistributedInferenceService:
self.is_running = False
def start_distributed_inference(self, args) -> bool:
if hasattr(args, "lora_path") and args.lora_path:
args.lora_configs = [{"path": args.lora_path, "strength": getattr(args, "lora_strength", 1.0)}]
delattr(args, "lora_path")
if hasattr(args, "lora_strength"):
delattr(args, "lora_strength")
self.args = args
if self.is_running:
logger.warning("Distributed inference service is already running")
......
......@@ -17,8 +17,7 @@ def get_default_config():
"teacache_thresh": 0.26,
"use_ret_steps": False,
"use_bfloat16": True,
"lora_path": None,
"strength_model": 1.0,
"lora_configs": None, # List of dicts with 'path' and 'strength' keys
"mm_config": {},
"use_prompt_enhancer": False,
}
......
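With this default in place, enabling LoRAs means assigning a list to `lora_configs`; a short usage sketch (the stub stands in for the hunk's `get_default_config`, paths are placeholders):
```python
def get_default_config() -> dict:
    # Stub carrying only the key relevant here; the real function sets many more.
    return {"lora_configs": None}

config = get_default_config()
config["lora_configs"] = [
    {"path": "/path/to/first_lora.safetensors", "strength": 0.8},
    {"path": "/path/to/second_lora.safetensors"},  # strength defaults to 1.0 downstream
]
print(config)
```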