Commit 5e08fd1a authored by PengGao, committed by GitHub

Enhance BaseRunner with model protocols and type hints and fix lora load (#117)

* Enhance BaseRunner with model protocols and type hints and fix lora load

* Add support for multiple LoRAs configuration and enhance lora_strength argument handling

* doc: style fix
parent 7a35c418
@@ -39,9 +39,29 @@ python -m lightx2v.infer \
     --model_path /path/to/model \
     --config_json /path/to/config.json \
     --lora_path /path/to/your/lora.safetensors \
+    --lora_strength 0.8 \
     --prompt "Your prompt here"
 ```
 
+### Multiple LoRAs Configuration
+
+To use multiple LoRAs with different strengths, specify them in the config JSON file:
+
+```json
+{
+    "lora_configs": [
+        {
+            "path": "/path/to/first_lora.safetensors",
+            "strength": 0.8
+        },
+        {
+            "path": "/path/to/second_lora.safetensors",
+            "strength": 0.5
+        }
+    ]
+}
+```
+
 ### Supported LoRA Formats
 
 LightX2V supports multiple LoRA weight naming conventions:
...
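For reference, a minimal sketch of how such a config file round-trips into the runner. The file name is a placeholder and the plain `json.load` is an assumption; LightX2V actually receives the file via `--config_json`:

```python
import json

# Hypothetical stand-in for what happens after --config_json is read:
# walk the lora_configs list, defaulting strength to 1.0 when omitted.
with open("/path/to/config.json") as f:
    config = json.load(f)

for lora_config in config.get("lora_configs") or []:
    path = lora_config["path"]
    strength = lora_config.get("strength", 1.0)
    print(f"Would load LoRA {path} at strength {strength}")
```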
@@ -38,9 +38,29 @@ python -m lightx2v.infer \
     --model_path /path/to/model \
     --config_json /path/to/config.json \
     --lora_path /path/to/your/lora.safetensors \
+    --lora_strength 0.8 \
     --prompt "Your prompt here"
 ```
 
+### Multi-LoRA Configuration
+
+To use multiple LoRAs with different strengths, specify them in the config JSON file:
+
+```json
+{
+    "lora_configs": [
+        {
+            "path": "/path/to/first_lora.safetensors",
+            "strength": 0.8
+        },
+        {
+            "path": "/path/to/second_lora.safetensors",
+            "strength": 0.5
+        }
+    ]
+}
+```
+
 ### Supported LoRA Formats
 
 LightX2V supports multiple LoRA weight naming conventions:
...
@@ -49,6 +49,7 @@ def main():
     parser.add_argument("--split", action="store_true")
     parser.add_argument("--lora_path", type=str, required=False, default=None)
+    parser.add_argument("--lora_strength", type=float, default=1.0, help="The strength for the lora (default: 1.0)")
     parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--seed", type=int, default=42)
     parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node for distributed inference")
...
@@ -53,6 +53,7 @@ async def main():
     parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
     parser.add_argument("--negative_prompt", type=str, default="")
     parser.add_argument("--lora_path", type=str, default="", help="The lora file path")
+    parser.add_argument("--lora_strength", type=float, default=1.0, help="The strength for the lora (default: 1.0)")
     parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
     parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
     parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
@@ -71,6 +72,11 @@ async def main():
             logger.error(f"Error reading prompt file: {e}")
             raise
 
+    if args.lora_path:
+        args.lora_configs = [{"path": args.lora_path, "strength": args.lora_strength}]
+        delattr(args, "lora_path")
+        delattr(args, "lora_strength")
+
     logger.info(f"args: {args}")
 
     with ProfilingContext("Total Cost"):
...
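The hunk above folds the scalar `--lora_path`/`--lora_strength` flags into the list-based `lora_configs` format before the args reach the runner. A standalone sketch of that normalization (the helper name is hypothetical, not part of the commit):

```python
from argparse import Namespace

def normalize_lora_args(args: Namespace) -> Namespace:
    """Fold --lora_path/--lora_strength into a one-element lora_configs list."""
    if getattr(args, "lora_path", None):
        args.lora_configs = [{"path": args.lora_path, "strength": getattr(args, "lora_strength", 1.0)}]
        # Drop the scalar attributes so downstream code only ever sees lora_configs.
        delattr(args, "lora_path")
        if hasattr(args, "lora_strength"):
            delattr(args, "lora_strength")
    return args

# Example: a single LoRA becomes [{'path': 'a.safetensors', 'strength': 0.8}]
print(normalize_lora_args(Namespace(lora_path="a.safetensors", lora_strength=0.8)))
```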
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Tuple, Optional
+from typing import Any, Dict, Tuple, Optional, Union, List, Protocol
 
 from lightx2v.utils.utils import save_videos_grid
 
 
+class TransformerModel(Protocol):
+    """Protocol for transformer models"""
+
+    def set_scheduler(self, scheduler: Any) -> None: ...
+
+    def scheduler(self) -> Any: ...
+
+
+class TextEncoderModel(Protocol):
+    """Protocol for text encoder models"""
+
+    def infer(self, texts: List[str], config: Dict[str, Any]) -> Any: ...
+
+
+class ImageEncoderModel(Protocol):
+    """Protocol for image encoder models"""
+
+    def encode(self, image: Any) -> Any: ...
+
+
+class VAEModel(Protocol):
+    """Protocol for VAE models"""
+
+    def encode(self, image: Any) -> Tuple[Any, Dict[str, Any]]: ...
+
+    def decode(self, latents: Any, generator: Optional[Any] = None, config: Optional[Dict[str, Any]] = None) -> Any: ...
+
+
 class BaseRunner(ABC):
     """Abstract base class for all Runners
@@ -13,34 +39,34 @@ class BaseRunner(ABC):
         self.config = config
 
     @abstractmethod
-    def load_transformer(self):
+    def load_transformer(self) -> TransformerModel:
         """Load transformer model
 
         Returns:
-            Loaded model instance
+            Loaded transformer model instance
         """
         pass
 
     @abstractmethod
-    def load_text_encoder(self):
+    def load_text_encoder(self) -> Union[TextEncoderModel, List[TextEncoderModel]]:
         """Load text encoder
 
         Returns:
-            Text encoder instance or list of instances
+            Text encoder instance or list of text encoder instances
         """
         pass
 
     @abstractmethod
-    def load_image_encoder(self):
+    def load_image_encoder(self) -> Optional[ImageEncoderModel]:
         """Load image encoder
 
         Returns:
-            Image encoder instance
+            Image encoder instance or None if not needed
         """
         pass
 
     @abstractmethod
-    def load_vae(self) -> Tuple[Any, Any]:
+    def load_vae(self) -> Tuple[VAEModel, VAEModel]:
         """Load VAE encoder and decoder
 
         Returns:
@@ -49,7 +75,7 @@ class BaseRunner(ABC):
         pass
 
     @abstractmethod
-    def run_image_encoder(self, img):
+    def run_image_encoder(self, img: Any) -> Any:
         """Run image encoder
 
         Args:
@@ -61,19 +87,19 @@ class BaseRunner(ABC):
         pass
 
     @abstractmethod
-    def run_vae_encoder(self, img):
+    def run_vae_encoder(self, img: Any) -> Tuple[Any, Dict[str, Any]]:
         """Run VAE encoder
 
         Args:
             img: Input image
 
         Returns:
-            VAE encoding result and additional parameters
+            Tuple of VAE encoding result and additional parameters
         """
         pass
 
     @abstractmethod
-    def run_text_encoder(self, prompt: str, img: Optional[Any] = None):
+    def run_text_encoder(self, prompt: str, img: Optional[Any] = None) -> Any:
         """Run text encoder
 
         Args:
@@ -86,7 +112,7 @@ class BaseRunner(ABC):
         pass
 
     @abstractmethod
-    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
+    def get_encoder_output_i2v(self, clip_encoder_out: Any, vae_encode_out: Any, text_encoder_output: Any, img: Any) -> Dict[str, Any]:
         """Combine encoder outputs for i2v task
 
         Args:
@@ -101,7 +127,7 @@ class BaseRunner(ABC):
         pass
 
     @abstractmethod
-    def init_scheduler(self):
+    def init_scheduler(self) -> None:
         """Initialize scheduler"""
         pass
 
@@ -115,7 +141,7 @@ class BaseRunner(ABC):
         """
         return {}
 
-    def save_video_func(self, images):
+    def save_video_func(self, images: Any) -> None:
         """Save video implementation
 
         Subclasses can override this method to customize save logic
@@ -125,7 +151,7 @@ class BaseRunner(ABC):
         """
         save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))
 
-    def load_vae_decoder(self):
+    def load_vae_decoder(self) -> VAEModel:
         """Load VAE decoder
 
         Default implementation: get decoder from load_vae method
...
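The new `Protocol` classes give the runner interfaces structural typing: a concrete model satisfies `TransformerModel` by shape alone, with no inheritance required. A minimal sketch of that behavior (the `DummyTransformer` class is hypothetical, not part of the commit):

```python
from typing import Any, Protocol

class TransformerModel(Protocol):
    def set_scheduler(self, scheduler: Any) -> None: ...
    def scheduler(self) -> Any: ...

class DummyTransformer:
    """Structurally matches TransformerModel without inheriting from it."""
    def __init__(self) -> None:
        self._scheduler: Any = None

    def set_scheduler(self, scheduler: Any) -> None:
        self._scheduler = scheduler

    def scheduler(self) -> Any:
        return self._scheduler

def attach_scheduler(model: TransformerModel, scheduler: Any) -> None:
    # A static checker accepts DummyTransformer here because it matches the protocol.
    model.set_scheduler(scheduler)

attach_scheduler(DummyTransformer(), scheduler=object())
```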
@@ -163,7 +163,10 @@ class DefaultRunner(BaseRunner):
         text_encoder_output = self.run_text_encoder(prompt, None)
         torch.cuda.empty_cache()
         gc.collect()
-        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
+        return {
+            "text_encoder_output": text_encoder_output,
+            "image_encoder_output": None,
+        }
 
     @ProfilingContext("Run DiT")
     def _run_dit_local(self, kwargs):
@@ -242,7 +245,13 @@ class DefaultRunner(BaseRunner):
         for url in self.config["sub_servers"]["prompt_enhancer"]:
             response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
             if response["service_status"] == "idle":
-                response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
+                response = requests.post(
+                    f"{url}/v1/local/prompt_enhancer/generate",
+                    json={
+                        "task_id": generate_task_id(),
+                        "prompt": self.config["prompt"],
+                    },
+                )
                 enhanced_prompt = response.json()["output"]
                 logger.info(f"Enhanced prompt: {enhanced_prompt}")
                 return enhanced_prompt
@@ -251,17 +260,36 @@ class DefaultRunner(BaseRunner):
         tasks = []
         img_byte = self.image_transporter.prepare_image(img)
         tasks.append(
-            asyncio.create_task(self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
+            asyncio.create_task(
+                self.post_task(
+                    task_type="image_encoder",
+                    urls=self.config["sub_servers"]["image_encoder"],
+                    message={"task_id": generate_task_id(), "img": img_byte},
+                    device="cuda",
+                )
+            )
         )
         tasks.append(
-            asyncio.create_task(self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
+            asyncio.create_task(
+                self.post_task(
+                    task_type="vae_model/encoder",
+                    urls=self.config["sub_servers"]["vae_model"],
+                    message={"task_id": generate_task_id(), "img": img_byte},
+                    device="cuda",
+                )
+            )
         )
         tasks.append(
             asyncio.create_task(
                 self.post_task(
                     task_type="text_encoders",
                     urls=self.config["sub_servers"]["text_encoders"],
-                    message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
+                    message={
+                        "task_id": generate_task_id(),
+                        "text": prompt,
+                        "img": img_byte,
+                        "n_prompt": n_prompt,
+                    },
                     device="cuda",
                 )
             )
@@ -277,7 +305,12 @@ class DefaultRunner(BaseRunner):
             self.post_task(
                 task_type="text_encoders",
                 urls=self.config["sub_servers"]["text_encoders"],
-                message={"task_id": generate_task_id(), "text": prompt, "img": None, "n_prompt": n_prompt},
+                message={
+                    "task_id": generate_task_id(),
+                    "text": prompt,
+                    "img": None,
+                    "n_prompt": n_prompt,
+                },
                 device="cuda",
             )
         )
@@ -290,7 +323,11 @@ class DefaultRunner(BaseRunner):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         n_prompt = self.config.get("negative_prompt", "")
         img = Image.open(self.config["image_path"]).convert("RGB")
-        clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders_i2v(prompt, img, n_prompt)
+        (
+            clip_encoder_out,
+            vae_encode_out,
+            text_encoder_output,
+        ) = await self.post_encoders_i2v(prompt, img, n_prompt)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
@@ -301,7 +338,10 @@ class DefaultRunner(BaseRunner):
         text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
         torch.cuda.empty_cache()
         gc.collect()
-        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
+        return {
+            "text_encoder_output": text_encoder_output,
+            "image_encoder_output": None,
+        }
 
     async def _run_dit_server(self, kwargs):
         if self.inputs.get("image_encoder_output", None) is not None:
@@ -309,7 +349,11 @@ class DefaultRunner(BaseRunner):
         dit_output = await self.post_task(
             task_type="dit",
             urls=self.config["sub_servers"]["dit"],
-            message={"task_id": generate_task_id(), "inputs": self.tensor_transporter.prepare_tensor(self.inputs), "kwargs": self.tensor_transporter.prepare_tensor(kwargs)},
+            message={
+                "task_id": generate_task_id(),
+                "inputs": self.tensor_transporter.prepare_tensor(self.inputs),
+                "kwargs": self.tensor_transporter.prepare_tensor(kwargs),
+            },
             device="cuda",
         )
         return dit_output, None
@@ -318,7 +362,10 @@ class DefaultRunner(BaseRunner):
         images = await self.post_task(
             task_type="vae_model/decoder",
             urls=self.config["sub_servers"]["vae_model"],
-            message={"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)},
+            message={
+                "task_id": generate_task_id(),
+                "latents": self.tensor_transporter.prepare_tensor(latents),
+            },
             device="cpu",
         )
         return images
...
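The reformatted `post_encoders_i2v` path dispatches the image-encoder, VAE-encoder, and text-encoder requests concurrently and awaits them together. A stripped-down sketch of that fan-out pattern, with a dummy `post_task` coroutine standing in for the runner method:

```python
import asyncio

async def post_task(task_type: str, urls, message, device: str):
    # Stand-in: the real method POSTs to an idle sub-server and returns its output.
    await asyncio.sleep(0)
    return f"{task_type} done"

async def fan_out(config, img_msg, text_msg):
    # Schedule all three encoder requests before awaiting any of them.
    tasks = [
        asyncio.create_task(post_task("image_encoder", config["sub_servers"]["image_encoder"], img_msg, "cuda")),
        asyncio.create_task(post_task("vae_model/encoder", config["sub_servers"]["vae_model"], img_msg, "cuda")),
        asyncio.create_task(post_task("text_encoders", config["sub_servers"]["text_encoders"], text_msg, "cuda")),
    ]
    return await asyncio.gather(*tasks)

config = {"sub_servers": {"image_encoder": [], "vae_model": [], "text_encoders": []}}
print(asyncio.run(fan_out(config, {"img": b""}, {"text": "prompt"})))
```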
@@ -278,12 +278,15 @@ class WanAudioRunner(WanRunner):
 
     def load_transformer(self):
         base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(base_model)
-            lora_name = lora_wrapper.load_lora(self.config.lora_path)
-            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-            logger.info(f"Loaded LoRA: {lora_name}")
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
+                lora_name = lora_wrapper.load_lora(lora_path)
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         return base_model
...
@@ -31,17 +31,19 @@ class WanCausVidRunner(WanRunner):
         self.num_fragments = self.config.num_fragments
 
     def load_transformer(self):
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             model = WanModel(
                 self.config.model_path,
                 self.config,
                 self.init_device,
             )
             lora_wrapper = WanLoraWrapper(model)
-            for lora_path in self.config.lora_path:
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
-                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-                logger.info(f"Loaded LoRA: {lora_name}")
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         else:
             model = WanCausVidModel(self.config.model_path, self.config, self.init_device)
         return model
...
@@ -24,17 +24,19 @@ class WanDistillRunner(WanRunner):
         super().__init__(config)
 
     def load_transformer(self):
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             model = WanModel(
                 self.config.model_path,
                 self.config,
                 self.init_device,
             )
             lora_wrapper = WanLoraWrapper(model)
-            for lora_path in self.config.lora_path:
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
-                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-                logger.info(f"Loaded LoRA: {lora_name}")
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         else:
             model = WanDistillModel(self.config.model_path, self.config, self.init_device)
         return model
...
@@ -36,13 +36,15 @@ class WanRunner(DefaultRunner):
             self.config,
             self.init_device,
         )
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(model)
-            for lora_path in self.config.lora_path:
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
-                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-                logger.info(f"Loaded LoRA: {lora_name}")
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         return model
 
     def load_image_encoder(self):
...
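The same load-and-apply loop now appears in `WanAudioRunner`, `WanCausVidRunner`, `WanDistillRunner`, and `WanRunner`. Were it ever factored out, a shared helper might look like this; the helper itself is a sketch and not part of the commit, though the `WanLoraWrapper` calls mirror the diff:

```python
def apply_lora_configs(lora_wrapper, lora_configs, logger=None):
    """Load and apply every LoRA in lora_configs via a WanLoraWrapper-like object."""
    for lora_config in lora_configs:
        lora_path = lora_config["path"]
        strength = lora_config.get("strength", 1.0)  # same default as the CLI flag
        lora_name = lora_wrapper.load_lora(lora_path)
        lora_wrapper.apply_lora(lora_name, strength)
        if logger is not None:
            logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")

# Usage inside a runner would then collapse to:
# apply_lora_configs(WanLoraWrapper(model), self.config.lora_configs, logger)
```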
@@ -186,6 +186,12 @@ class DistributedInferenceService:
         self.is_running = False
 
     def start_distributed_inference(self, args) -> bool:
+        if hasattr(args, "lora_path") and args.lora_path:
+            args.lora_configs = [{"path": args.lora_path, "strength": getattr(args, "lora_strength", 1.0)}]
+            delattr(args, "lora_path")
+            if hasattr(args, "lora_strength"):
+                delattr(args, "lora_strength")
+
         self.args = args
         if self.is_running:
             logger.warning("Distributed inference service is already running")
...
@@ -17,8 +17,7 @@ def get_default_config():
         "teacache_thresh": 0.26,
         "use_ret_steps": False,
         "use_bfloat16": True,
-        "lora_path": None,
-        "strength_model": 1.0,
+        "lora_configs": None,  # List of dicts with 'path' and 'strength' keys
         "mm_config": {},
         "use_prompt_enhancer": False,
     }
...
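Since `lora_configs` defaults to `None` and is otherwise expected to be a list of dicts, a small validator could catch malformed entries early. A hedged sketch under that assumption (the validator is not part of the commit):

```python
from typing import Any, Dict, List, Optional

def validate_lora_configs(lora_configs: Optional[List[Dict[str, Any]]]) -> None:
    """Raise early if a lora_configs entry is missing 'path' or has a bad strength."""
    if lora_configs is None:
        return  # the default: no LoRA is applied
    for i, entry in enumerate(lora_configs):
        if "path" not in entry:
            raise ValueError(f"lora_configs[{i}] is missing the required 'path' key")
        strength = entry.get("strength", 1.0)
        if not isinstance(strength, (int, float)):
            raise TypeError(f"lora_configs[{i}]['strength'] must be a number, got {type(strength).__name__}")

validate_lora_configs([{"path": "/path/to/first_lora.safetensors", "strength": 0.8}])
```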