Commit a8ad2d7d authored by lijiaqi2, committed by gaopeng

feat: support LoRA

parent 6c18f54c
 import argparse
+from contextlib import contextmanager
 import torch
 import torch.distributed as dist
 import os
@@ -21,6 +22,8 @@ from lightx2v.text2v.models.schedulers.wan.feature_caching.scheduler import WanS
 from lightx2v.text2v.models.networks.hunyuan.model import HunyuanModel
 from lightx2v.text2v.models.networks.wan.model import WanModel
+from lightx2v.text2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
 from lightx2v.text2v.models.video_encoders.hf.wan.vae import WanVAE
@@ -29,6 +32,14 @@ from lightx2v.common.ops import *
 from lightx2v.image2v.models.wan.model import CLIPModel
+
+
+@contextmanager
+def time_duration(label: str = ""):
+    # Time the wrapped block and print its start timestamp and duration.
+    start_time = time.time()
+    yield
+    end_time = time.time()
+    print(f"==> {label} start:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} cost {end_time - start_time:.2f} seconds")
+
+
 def load_models(args, model_config):
     if model_config["parallel_attn_type"]:
         cur_rank = dist.get_rank()  # get the rank of the current process
@@ -59,15 +70,27 @@ def load_models(args, model_config):
             shard_fn=None,
         )
         text_encoders = [text_encoder]
-        model = WanModel(args.model_path, model_config, init_device)
-        vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=args.parallel_vae)
+        with time_duration("Load Wan Model"):
+            model = WanModel(args.model_path, model_config, init_device)
+
+        if args.lora_path:
+            lora_wrapper = WanLoraWrapper(model)
+            with time_duration("Load LoRA Model"):
+                lora_name = lora_wrapper.load_lora(args.lora_path)
+                lora_wrapper.apply_lora(lora_name, args.strength_model)
+                print(f"Loaded LoRA: {lora_name}")
+
+        with time_duration("Load WAN VAE Model"):
+            vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=args.parallel_vae)
         if args.task == "i2v":
-            image_encoder = CLIPModel(
-                dtype=torch.float16,
-                device=init_device,
-                checkpoint_path=os.path.join(args.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
-                tokenizer_path=os.path.join(args.model_path, "xlm-roberta-large"),
-            )
+            with time_duration("Load Image Encoder"):
+                image_encoder = CLIPModel(
+                    dtype=torch.float16,
+                    device=init_device,
+                    checkpoint_path=os.path.join(args.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
+                    tokenizer_path=os.path.join(args.model_path, "xlm-roberta-large"),
+                )
     else:
         raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
@@ -312,6 +335,10 @@ if __name__ == "__main__":
     parser.add_argument("--patch_size", default=(1, 2, 2))
     parser.add_argument("--teacache_thresh", type=float, default=0.26)
     parser.add_argument("--use_ret_steps", action="store_true", default=False)
+    parser.add_argument("--use_bfloat16", action="store_true", default=True)
+    parser.add_argument("--lora_path", type=str, default=None)
+    parser.add_argument("--strength_model", type=float, default=1.0)
     args = parser.parse_args()

     start_time = time.time()
@@ -338,6 +365,7 @@ if __name__ == "__main__":
         "feature_caching": args.feature_caching,
         "parallel_attn_type": args.parallel_attn_type,
         "parallel_vae": args.parallel_vae,
+        "use_bfloat16": args.use_bfloat16,
     }

     if args.config_path is not None:
@@ -347,10 +375,8 @@ if __name__ == "__main__":
     print(f"model_config: {model_config}")

-    model, text_encoders, vae_model, image_encoder = load_models(args, model_config)
-
-    load_models_time = time.time()
-    print(f"Load models cost: {load_models_time - start_time}")
+    with time_duration("Load models"):
+        model, text_encoders, vae_model, image_encoder = load_models(args, model_config)

     if args.task in ["i2v"]:
         image_encoder_output = run_image_encoder(args, image_encoder, vae_model)
...
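The LoRA support added above is driven by three new flags: --lora_path points at a .safetensors LoRA file and defaults to None, which leaves the model untouched; --strength_model is the merge scale (the alpha passed to apply_lora, default 1.0); and --use_bfloat16 selects the load dtype for both the base checkpoint and the LoRA tensors (note that with action="store_true" and default=True, the flag as committed cannot actually be switched off from the command line). When --lora_path is set, load_models wraps the freshly constructed WanModel in a WanLoraWrapper and merges the LoRA into the base weights before the VAE and image encoder load, so downstream inference code is unchanged.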
lightx2v/text2v/models/networks/wan/lora_adapter.py (new file; path inferred from the import added above):

import os
import gc

import torch
from safetensors import safe_open
from loguru import logger


class WanLoraWrapper:
    def __init__(self, wan_model):
        self.model = wan_model
        self.lora_dict = {}
        self.override_dict = {}  # backups of base weights, taken before a LoRA is merged

    def load_lora(self, lora_path, lora_name=None):
        if lora_name is None:
            lora_name = os.path.basename(lora_path).split(".")[0]

        if lora_name in self.lora_dict:
            logger.info(f"LoRA {lora_name} already loaded, skipping...")
            return lora_name

        lora_weights = self._load_lora_file(lora_path)
        self.lora_dict[lora_name] = lora_weights
        return lora_name

    def _load_lora_file(self, file_path):
        use_bfloat16 = True  # default value
        if self.model.config and hasattr(self.model.config, "get"):
            use_bfloat16 = self.model.config.get("use_bfloat16", True)
        with safe_open(file_path, framework="pt") as f:
            if use_bfloat16:
                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16) for key in f.keys()}
            else:
                tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
        return tensor_dict

    def apply_lora(self, lora_name, alpha=1.0):
        if lora_name not in self.lora_dict:
            logger.error(f"LoRA {lora_name} not found. Please load it first.")
            return False

        if hasattr(self.model, "current_lora") and self.model.current_lora:
            self.remove_lora()

        if not hasattr(self.model, "original_weight_dict"):
            logger.error("Model does not have 'original_weight_dict'. Cannot apply LoRA.")
            return False

        weight_dict = self.model.original_weight_dict
        lora_weights = self.lora_dict[lora_name]
        self._apply_lora_weights(weight_dict, lora_weights, alpha)

        # Reload the merged weights into the model
        self.model.pre_weight.load_weights(weight_dict)
        self.model.post_weight.load_weights(weight_dict)
        self.model.transformer_weights.load_weights(weight_dict)

        self.model.current_lora = lora_name
        logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
        return True

    def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
        # Pair each lora_A/lora_B tensor with the base weight it modifies.
        lora_pairs = {}
        prefix = "diffusion_model."
        for key in lora_weights.keys():
            if key.endswith("lora_A.weight") and key.startswith(prefix):
                base_name = key[len(prefix) :].replace("lora_A.weight", "weight")
                b_key = key.replace("lora_A.weight", "lora_B.weight")
                if b_key in lora_weights:
                    lora_pairs[base_name] = (key, b_key)

        applied_count = 0
        for name, param in weight_dict.items():
            if name in lora_pairs:
                # Snapshot the untouched weight so remove_lora() can restore it.
                if name not in self.override_dict:
                    self.override_dict[name] = param.clone().cpu()
                name_lora_A, name_lora_B = lora_pairs[name]
                lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
                lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
                # In-place merge: W <- W + alpha * (B @ A)
                param += torch.matmul(lora_B, lora_A) * alpha
                applied_count += 1

        logger.info(f"Applied {applied_count} LoRA weight adjustments")
        if applied_count == 0:
            logger.warning(
                "No LoRA weights were applied. Expected naming conventions: 'diffusion_model.<layer_name>.lora_A.weight' and 'diffusion_model.<layer_name>.lora_B.weight'. Please verify the LoRA weight file."
            )

    def remove_lora(self):
        if not self.model.current_lora:
            logger.info("No LoRA currently applied")
            return

        logger.info(f"Removing LoRA {self.model.current_lora}...")
        restored_count = 0
        for k, v in self.override_dict.items():
            self.model.original_weight_dict[k] = v.to(self.model.device)
            restored_count += 1
        logger.info(f"LoRA {self.model.current_lora} removed, restored {restored_count} weights")

        self.model.pre_weight.load_weights(self.model.original_weight_dict)
        self.model.post_weight.load_weights(self.model.original_weight_dict)
        self.model.transformer_weights.load_weights(self.model.original_weight_dict)

        if self.model.current_lora in self.lora_dict:
            del self.lora_dict[self.model.current_lora]
        self.model.current_lora = None
        self.override_dict = {}
        torch.cuda.empty_cache()
        gc.collect()

    def list_loaded_loras(self):
        return list(self.lora_dict.keys())

    def get_current_lora(self):
        return self.model.current_lora
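Because the wrapper only touches a handful of attributes on the model (original_weight_dict, the three weight holders, current_lora, device, config), it can be exercised end to end without a real checkpoint. Below is a minimal smoke-test sketch, not part of the commit; the stub classes and tensor shapes are hypothetical, and save_file is the standard safetensors writer:

import os
import tempfile

import torch
from safetensors.torch import save_file

from lightx2v.text2v.models.networks.wan.lora_adapter import WanLoraWrapper


class _WeightHolder:
    # The real weight holders copy tensors into modules; a no-op suffices here.
    def load_weights(self, weight_dict):
        pass


class _StubModel:
    # Hypothetical stand-in exposing only the attributes WanLoraWrapper uses.
    def __init__(self):
        self.config = {"use_bfloat16": True}
        self.device = "cpu"
        self.current_lora = None
        self.original_weight_dict = {"blocks.0.attn.q.weight": torch.zeros(8, 4)}
        self.pre_weight = _WeightHolder()
        self.post_weight = _WeightHolder()
        self.transformer_weights = _WeightHolder()


# A tiny LoRA file using the naming convention _apply_lora_weights expects.
lora = {
    "diffusion_model.blocks.0.attn.q.lora_A.weight": torch.ones(2, 4),
    "diffusion_model.blocks.0.attn.q.lora_B.weight": torch.ones(8, 2),
}
path = os.path.join(tempfile.mkdtemp(), "test_lora.safetensors")
save_file(lora, path)

model = _StubModel()
wrapper = WanLoraWrapper(model)
name = wrapper.load_lora(path)
wrapper.apply_lora(name, alpha=0.5)  # W <- W + 0.5 * (B @ A) = all ones here
assert torch.allclose(model.original_weight_dict["blocks.0.attn.q.weight"], torch.ones(8, 4))
wrapper.remove_lora()  # restores the snapshot taken in _apply_lora_weights
assert torch.count_nonzero(model.original_weight_dict["blocks.0.attn.q.weight"]) == 0

The final assert relies on the backup-and-restore path: _apply_lora_weights snapshots each touched weight into override_dict before merging, and remove_lora copies the snapshots back.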
lightx2v/text2v/models/networks/wan/model.py (path per the WanModel import above):

@@ -30,6 +30,7 @@ class WanModel:
         self._init_infer_class()
         self._init_weights()
         self._init_infer()
+        self.current_lora = None

         if config["parallel_attn_type"]:
             if config["parallel_attn_type"] == "ulysses":
@@ -53,8 +54,12 @@ class WanModel:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

     def _load_safetensor_to_dict(self, file_path):
+        use_bfloat16 = self.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
-            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
+            if use_bfloat16:
+                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
+            else:
+                tensor_dict = {key: f.get_tensor(key).to(self.device) for key in f.keys()}
         return tensor_dict

     def _load_ckpt(self):
@@ -69,16 +74,19 @@ class WanModel:
             weight_dict.update(file_weights)
         return weight_dict

-    def _init_weights(self):
-        weight_dict = self._load_ckpt()
+    def _init_weights(self, weight_dict=None):
+        if weight_dict is None:
+            self.original_weight_dict = self._load_ckpt()
+        else:
+            self.original_weight_dict = weight_dict

         # init weights
         self.pre_weight = self.pre_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)

         # load weights
-        self.pre_weight.load_weights(weight_dict)
-        self.post_weight.load_weights(weight_dict)
-        self.transformer_weights.load_weights(weight_dict)
+        self.pre_weight.load_weights(self.original_weight_dict)
+        self.post_weight.load_weights(self.original_weight_dict)
+        self.transformer_weights.load_weights(self.original_weight_dict)

     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
...
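For reference, the merge that _apply_lora_weights performs is W' = W + alpha * (B @ A): for a base weight W of shape (out, in), lora_A.weight has shape (r, in) and lora_B.weight has shape (out, r), so the update matches W's shape exactly. Both loaders cast tensors to bfloat16 when use_bfloat16 is set, and the LoRA factors are moved to the base weight's device and dtype before the matmul. A standalone shape check with illustrative dimensions:

import torch

# Illustrative shapes: W is (out, in), A is (r, in), B is (out, r); r is the LoRA rank.
out_dim, in_dim, rank, alpha = 8, 4, 2, 0.8
W = torch.randn(out_dim, in_dim, dtype=torch.bfloat16)
A = torch.randn(rank, in_dim, dtype=torch.bfloat16)   # lora_A.weight
B = torch.randn(out_dim, rank, dtype=torch.bfloat16)  # lora_B.weight

delta = torch.matmul(B, A) * alpha  # (out, r) @ (r, in) -> (out, in)
assert delta.shape == W.shape
W += delta  # in-place merge, as in _apply_lora_weights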