Commit 6347b21b authored by root

Support converting weights to Diffusers.

parent aec90a0d
......@@ -7,6 +7,8 @@
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": false,
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
"weight_auto_quant": true
......
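The two new flags added to this config, enable_cfg and cpu_offload, are read by the runner at inference time (see the WanModel change further down, which gates the unconditional pass on enable_cfg). A minimal sketch of inspecting such a config file; the filename wan_i2v.json is a placeholder, the real path is whatever is passed via --config_json:

import json
from loguru import logger

# Placeholder path; in practice the config comes from the --config_json argument.
with open("wan_i2v.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)

# These flags gate classifier-free guidance and weight offloading in the runner.
logger.info(f"enable_cfg={cfg.get('enable_cfg', False)}, cpu_offload={cfg.get('cpu_offload', False)}")
logger.info(f"guide_scale={cfg.get('sample_guide_scale')}, mm_type={cfg['mm_config']['mm_type']}")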
import os
import re
import glob
import json
import argparse
import torch
from safetensors import safe_open, torch as st
from loguru import logger
from tqdm import tqdm
def get_key_mapping_rules(direction, model_type):
if model_type == "wan":
unified_rules = [
{"forward": (r"^head\.head$", "proj_out"), "backward": (r"^proj_out$", "head.head")},
{"forward": (r"^head\.modulation$", "scale_shift_table"), "backward": (r"^scale_shift_table$", "head.modulation")},
{"forward": (r"^text_embedding\.0\.", "condition_embedder.text_embedder.linear_1."), "backward": (r"^condition_embedder.text_embedder.linear_1\.", "text_embedding.0.")},
{"forward": (r"^text_embedding\.2\.", "condition_embedder.text_embedder.linear_2."), "backward": (r"^condition_embedder.text_embedder.linear_2\.", "text_embedding.2.")},
{"forward": (r"^time_embedding\.0\.", "condition_embedder.time_embedder.linear_1."), "backward": (r"^condition_embedder.time_embedder.linear_1\.", "time_embedding.0.")},
{"forward": (r"^time_embedding\.2\.", "condition_embedder.time_embedder.linear_2."), "backward": (r"^condition_embedder.time_embedder.linear_2\.", "time_embedding.2.")},
{"forward": (r"^time_projection\.1\.", "condition_embedder.time_proj."), "backward": (r"^condition_embedder.time_proj\.", "time_projection.1.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.q\.", r"blocks.\1.attn1.to_q."), "backward": (r"blocks\.(\d+)\.attn1\.to_q\.", r"blocks.\1.self_attn.q.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.k\.", r"blocks.\1.attn1.to_k."), "backward": (r"blocks\.(\d+)\.attn1\.to_k\.", r"blocks.\1.self_attn.k.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.v\.", r"blocks.\1.attn1.to_v."), "backward": (r"blocks\.(\d+)\.attn1\.to_v\.", r"blocks.\1.self_attn.v.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.o\.", r"blocks.\1.attn1.to_out.0."), "backward": (r"blocks\.(\d+)\.attn1\.to_out\.0\.", r"blocks.\1.self_attn.o.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.q\.", r"blocks.\1.attn2.to_q."), "backward": (r"blocks\.(\d+)\.attn2\.to_q\.", r"blocks.\1.cross_attn.q.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.k\.", r"blocks.\1.attn2.to_k."), "backward": (r"blocks\.(\d+)\.attn2\.to_k\.", r"blocks.\1.cross_attn.k.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.v\.", r"blocks.\1.attn2.to_v."), "backward": (r"blocks\.(\d+)\.attn2\.to_v\.", r"blocks.\1.cross_attn.v.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.o\.", r"blocks.\1.attn2.to_out.0."), "backward": (r"blocks\.(\d+)\.attn2\.to_out\.0\.", r"blocks.\1.cross_attn.o.")},
{"forward": (r"blocks\.(\d+)\.norm3\.", r"blocks.\1.norm2."), "backward": (r"blocks\.(\d+)\.norm2\.", r"blocks.\1.norm3.")},
{"forward": (r"blocks\.(\d+)\.ffn\.0\.", r"blocks.\1.ffn.net.0.proj."), "backward": (r"blocks\.(\d+)\.ffn\.net\.0\.proj\.", r"blocks.\1.ffn.0.")},
{"forward": (r"blocks\.(\d+)\.ffn\.2\.", r"blocks.\1.ffn.net.2."), "backward": (r"blocks\.(\d+)\.ffn\.net\.2\.", r"blocks.\1.ffn.2.")},
{"forward": (r"blocks\.(\d+)\.modulation\.", r"blocks.\1.scale_shift_table."), "backward": (r"blocks\.(\d+)\.scale_shift_table(?=\.|$)", r"blocks.\1.modulation")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.k_img\.", r"blocks.\1.attn2.add_k_proj."), "backward": (r"blocks\.(\d+)\.attn2\.add_k_proj\.", r"blocks.\1.cross_attn.k_img.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.v_img\.", r"blocks.\1.attn2.add_v_proj."), "backward": (r"blocks\.(\d+)\.attn2\.add_v_proj\.", r"blocks.\1.cross_attn.v_img.")},
{
"forward": (r"blocks\.(\d+)\.cross_attn\.norm_k_img\.weight", r"blocks.\1.attn2.norm_added_k.weight"),
"backward": (r"blocks\.(\d+)\.attn2\.norm_added_k\.weight", r"blocks.\1.cross_attn.norm_k_img.weight"),
},
{"forward": (r"img_emb\.proj\.0\.", r"condition_embedder.image_embedder.norm1."), "backward": (r"condition_embedder\.image_embedder\.norm1\.", r"img_emb.proj.0.")},
{"forward": (r"img_emb\.proj\.1\.", r"condition_embedder.image_embedder.ff.net.0.proj."), "backward": (r"condition_embedder\.image_embedder\.ff\.net\.0\.proj\.", r"img_emb.proj.1.")},
{"forward": (r"img_emb\.proj\.3\.", r"condition_embedder.image_embedder.ff.net.2."), "backward": (r"condition_embedder\.image_embedder\.ff\.net\.2\.", r"img_emb.proj.3.")},
{"forward": (r"img_emb\.proj\.4\.", r"condition_embedder.image_embedder.norm2."), "backward": (r"condition_embedder\.image_embedder\.norm2\.", r"img_emb.proj.4.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.norm_q\.weight", r"blocks.\1.attn1.norm_q.weight"), "backward": (r"blocks\.(\d+)\.attn1\.norm_q\.weight", r"blocks.\1.self_attn.norm_q.weight")},
{"forward": (r"blocks\.(\d+)\.self_attn\.norm_k\.weight", r"blocks.\1.attn1.norm_k.weight"), "backward": (r"blocks\.(\d+)\.attn1\.norm_k\.weight", r"blocks.\1.self_attn.norm_k.weight")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.norm_q\.weight", r"blocks.\1.attn2.norm_q.weight"), "backward": (r"blocks\.(\d+)\.attn2\.norm_q\.weight", r"blocks.\1.cross_attn.norm_q.weight")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.norm_k\.weight", r"blocks.\1.attn2.norm_k.weight"), "backward": (r"blocks\.(\d+)\.attn2\.norm_k\.weight", r"blocks.\1.cross_attn.norm_k.weight")},
# head projection mapping
{"forward": (r"^head\.head\.", "proj_out."), "backward": (r"^proj_out\.", "head.head.")},
]
if direction == "forward":
return [rule["forward"] for rule in unified_rules]
elif direction == "backward":
return [rule["backward"] for rule in unified_rules]
else:
raise ValueError(f"Invalid direction: {direction}")
else:
raise ValueError(f"Unsupported model type: {model_type}")
def convert_weights(args):
if os.path.isdir(args.source):
src_files = glob.glob(os.path.join(args.source, "*.safetensors"), recursive=True)
elif args.source.endswith((".pth", ".safetensors", ".pt")):
src_files = [args.source]
else:
raise ValueError("Invalid input path")
merged_weights = {}
logger.info(f"Processing source files: {src_files}")
for file_path in tqdm(src_files, desc="Loading weights"):
logger.info(f"Loading weights from: {file_path}")
if file_path.endswith(".pt") or file_path.endswith(".pth"):
weights = torch.load(file_path, map_location="cpu", weights_only=True)
elif file_path.endswith(".safetensors"):
with safe_open(file_path, framework="pt") as f:
weights = {k: f.get_tensor(k) for k in f.keys()}
duplicate_keys = set(weights.keys()) & set(merged_weights.keys())
if duplicate_keys:
raise ValueError(f"Duplicate keys found: {duplicate_keys} in file {file_path}")
merged_weights.update(weights)
rules = get_key_mapping_rules(args.direction, args.model_type)
converted_weights = {}
logger.info("Converting keys...")
for key in tqdm(merged_weights.keys(), desc="Converting keys"):
new_key = key
for pattern, replacement in rules:
new_key = re.sub(pattern, replacement, new_key)
converted_weights[new_key] = merged_weights[key]
os.makedirs(args.output, exist_ok=True)
base_name = os.path.splitext(os.path.basename(args.source))[0] if args.source.endswith((".pth", ".safetensors", ".pt")) else "converted_model"
index = {"metadata": {"total_size": 0}, "weight_map": {}}
chunk_idx = 0
current_chunk = {}
for idx, (k, v) in tqdm(enumerate(converted_weights.items()), desc="Saving chunks"):
current_chunk[k] = v
if args.chunk_size > 0 and (idx + 1) % args.chunk_size == 0:
output_filename = f"{base_name}_part{chunk_idx}.safetensors"
output_path = os.path.join(args.output, output_filename)
logger.info(f"Saving chunk to: {output_path}")
st.save_file(current_chunk, output_path)
for key in current_chunk:
index["weight_map"][key] = output_filename
index["metadata"]["total_size"] += os.path.getsize(output_path)
current_chunk = {}
chunk_idx += 1
if current_chunk:
output_filename = f"{base_name}_part{chunk_idx}.safetensors"
output_path = os.path.join(args.output, output_filename)
logger.info(f"Saving final chunk to: {output_path}")
st.save_file(current_chunk, output_path)
for key in current_chunk:
index["weight_map"][key] = output_filename
index["metadata"]["total_size"] += os.path.getsize(output_path)
# Save index file
index_path = os.path.join(args.output, "diffusion_pytorch_model.safetensors.index.json")
with open(index_path, "w", encoding="utf-8") as f:
json.dump(index, f, indent=2)
logger.info(f"Index file written to: {index_path}")
def main():
parser = argparse.ArgumentParser(description="Model weight format converter")
parser.add_argument("-s", "--source", required=True, help="Input path (file or directory)")
parser.add_argument("-o", "--output", required=True, help="Output directory path")
parser.add_argument("-d", "--direction", choices=["forward", "backward"], default="forward", help="Conversion direction: forward = 'lightx2v' -> 'Diffusers', backward = reverse")
parser.add_argument("-c", "--chunk-size", type=int, default=100, help="Chunk size for saving (only applies to forward), 0 = no chunking")
parser.add_argument("-t", "--model_type", choices=["wan"], default="wan", help="Model type")
args = parser.parse_args()
if os.path.isfile(args.output):
raise ValueError("Output path must be a directory, not a file")
logger.info("Starting model weight conversion...")
convert_weights(args)
logger.info(f"Conversion completed! Files saved to: {args.output}")
if __name__ == "__main__":
main()
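For reference, the converter can also be driven programmatically with the same fields the CLI parses; the paths below are placeholders, not files from this repository:

import argparse

# Placeholder paths; on the command line these come from -s/--source and -o/--output.
args = argparse.Namespace(
    source="./wan2.1_t2v.safetensors",
    output="./converted",
    direction="forward",  # lightx2v -> Diffusers
    chunk_size=100,
    model_type="wan",
)
convert_weights(args)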
......@@ -4,6 +4,7 @@ import psutil
import argparse
from fastapi import FastAPI, Request
from pydantic import BaseModel
from loguru import logger
import uvicorn
import json
import asyncio
......@@ -26,15 +27,15 @@ def kill_all_related_processes():
try:
child.kill()
except Exception as e:
print(f"Failed to kill child process {child.pid}: {e}")
logger.info(f"Failed to kill child process {child.pid}: {e}")
try:
current_process.kill()
except Exception as e:
print(f"Failed to kill main process: {e}")
logger.info(f"Failed to kill main process: {e}")
def signal_handler(sig, frame):
print("\nReceived Ctrl+C, shutting down all related processes...")
logger.info("\nReceived Ctrl+C, shutting down all related processes...")
kill_all_related_processes()
sys.exit(0)
......@@ -79,11 +80,11 @@ if __name__ == "__main__":
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--port", type=int, default=8000)
args = parser.parse_args()
print(f"args: {args}")
logger.info(f"args: {args}")
with ProfilingContext("Init Server Cost"):
config = set_config(args)
print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config)
uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
from typing import Optional
from loguru import logger
import torch
import torch.distributed as dist
......@@ -21,7 +22,7 @@ class RingComm:
def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
if recv_tensor is None:
res = torch.empty_like(to_send)
# print(f"send_recv: empty_like {to_send.shape}")
# logger.info(f"send_recv: empty_like {to_send.shape}")
else:
res = recv_tensor
......
......@@ -2,6 +2,7 @@ import torch
import torch.distributed as dist
from lightx2v.attentions import attention
from lightx2v.utils.utils import seed_all
from loguru import logger
seed_all(42)
......@@ -65,10 +66,10 @@ def test_part_head():
# Verify result consistency
if cur_rank == 0:
# import pdb; pdb.set_trace()
print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
logger.info(f"Outputs match: {torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3)}")
# # Verify result consistency
# print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
# logger.info(f"Outputs match: {torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3)}")
if __name__ == "__main__":
......
......@@ -20,6 +20,7 @@ import os
import tensorrt as trt
from .common_runtime import *
from loguru import logger
try:
# Sometimes python does not understand FileNotFoundError
......@@ -67,11 +68,11 @@ def find_sample_data(description="Runs a TensorRT Python sample", subfolder="",
data_path = os.path.join(data_dir, subfolder)
if not os.path.exists(data_path):
if data_dir != kDEFAULT_DATA_ROOT:
print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
logger.info("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
data_path = data_dir
# Make sure data directory exists.
if not (os.path.exists(data_path)) and data_dir != kDEFAULT_DATA_ROOT:
print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
logger.info("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
return data_path
data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
......
......@@ -4,6 +4,7 @@ from vllm import _custom_ops as ops
import sgl_kernel
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
from loguru import logger
try:
import q8_kernels.functional as Q8F
......@@ -461,7 +462,7 @@ if __name__ == "__main__":
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
print(output_tensor.shape)
logger.info(output_tensor.shape)
weight_dict = {
"xx.weight": torch.randn(8192, 4096),
......@@ -473,7 +474,7 @@ if __name__ == "__main__":
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
print(output_tensor.shape)
logger.info(output_tensor.shape)
weight_dict = {
"xx.weight": torch.randn(8192, 4096),
......@@ -485,4 +486,4 @@ if __name__ == "__main__":
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
print(output_tensor.shape)
logger.info(output_tensor.shape)
......@@ -15,6 +15,7 @@ from lightx2v.models.runners.wan.wan_causal_runner import WanCausalRunner
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.common.ops import *
from loguru import logger
def init_runner(config):
......@@ -37,17 +38,17 @@ if __name__ == "__main__":
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--enable_cfg", type=bool, default=False)
parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
args = parser.parse_args()
print(f"args: {args}")
logger.info(f"args: {args}")
with ProfilingContext("Total Cost"):
config = set_config(args)
print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config)
runner.run_pipeline()
import torch
from transformers import CLIPTextModel, AutoTokenizer
from loguru import logger
class TextEncoderHFClipModel:
......@@ -54,4 +55,4 @@ if __name__ == "__main__":
model = TextEncoderHFClipModel(model_path, torch.device("cuda"))
text = "A cat walks on the grass, realistic style."
outputs = model.infer(text)
print(outputs)
logger.info(outputs)
import torch
from transformers import AutoModel, AutoTokenizer
from loguru import logger
class TextEncoderHFLlamaModel:
......@@ -67,4 +68,4 @@ if __name__ == "__main__":
model = TextEncoderHFLlamaModel(model_path, torch.device("cuda"))
text = "A cat walks on the grass, realistic style."
outputs = model.infer(text)
print(outputs)
logger.info(outputs)
......@@ -3,6 +3,7 @@ from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from transformers import LlavaForConditionalGeneration, CLIPImageProcessor, AutoTokenizer
from loguru import logger
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
......@@ -158,4 +159,4 @@ if __name__ == "__main__":
img_path = "/mtc/yongyang/projects/lightx2v/assets/inputs/imgs/img_1.jpg"
img = Image.open(img_path).convert("RGB")
outputs = model.infer(text, img, None)
print(outputs)
logger.info(outputs)
......@@ -8,6 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .tokenizer import HuggingfaceTokenizer
from loguru import logger
__all__ = [
"T5Model",
......@@ -522,4 +523,4 @@ if __name__ == "__main__":
)
text = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
outputs = model.infer(text)
print(outputs)
logger.info(outputs)
......@@ -10,6 +10,7 @@ import torchvision.transforms as T
from lightx2v.attentions import attention
from lightx2v.models.input_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
from loguru import logger
from .xlm_roberta import XLMRoberta
......@@ -190,7 +191,7 @@ class VisionTransformer(nn.Module):
norm_eps=1e-5,
):
if image_size % patch_size != 0:
print("[WARNING] image_size is not divisible by patch_size", flush=True)
logger.info("[WARNING] image_size is not divisible by patch_size", flush=True)
assert pool_type in ("token", "token_fc", "attn_pool")
out_dim = out_dim or dim
super().__init__()
......
......@@ -123,18 +123,20 @@ class WanModel:
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0
        self.scheduler.noise_pred = noise_pred_cond
        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
            if self.config["feature_caching"] == "Tea":
                self.scheduler.cnt += 1
                if self.scheduler.cnt >= self.scheduler.num_steps:
                    self.scheduler.cnt = 0
            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()
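The combination above is standard classifier-free guidance: the final prediction starts from the unconditional output and moves toward the conditional one by sample_guide_scale, and when enable_cfg is false the conditional prediction is used directly. A minimal numeric sketch with dummy tensors (the latent shape is illustrative, not the model's real shape):

import torch

# Illustrative stand-ins for the transformer's two noise predictions.
noise_pred_cond = torch.randn(16, 21, 60, 104)
noise_pred_uncond = torch.randn(16, 21, 60, 104)
sample_guide_scale = 5  # same value as in the config hunk above

# Classifier-free guidance: uncond + scale * (cond - uncond).
noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
# With enable_cfg false, the unconditional pass is skipped and noise_pred_cond is used as-is.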
......@@ -4,6 +4,7 @@ import torch.distributed as dist
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
from lightx2v.utils.envs import *
from loguru import logger
class DefaultRunner:
......@@ -32,7 +33,7 @@ class DefaultRunner:
def run(self):
for step_index in range(self.model.scheduler.infer_steps):
print(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
with ProfilingContext4Debug("step_pre"):
self.model.scheduler.step_pre(step_index=step_index)
......
from lightx2v.utils.profiler import ProfilingContext4Debug
from loguru import logger
class GraphRunner:
......@@ -7,10 +8,10 @@ class GraphRunner:
self.compile()
def compile(self):
print("start compile...")
logger.info("start compile...")
with ProfilingContext4Debug("compile"):
self.runner.run_step()
print("end compile...")
logger.info("end compile...")
def run_pipeline(self):
return self.runner.run_pipeline()
......@@ -14,6 +14,7 @@ from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.causal_model import WanCausalModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from loguru import logger
import torch.distributed as dist
......@@ -54,7 +55,7 @@ class WanCausalRunner(WanRunner):
lora_wrapper = WanLoraWrapper(model)
lora_name = lora_wrapper.load_lora(self.config.lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model)
print(f"Loaded LoRA: {lora_name}")
logger.info(f"Loaded LoRA: {lora_name}")
vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
if self.config.task == "i2v":
......@@ -95,13 +96,13 @@ class WanCausalRunner(WanRunner):
start_block_idx = 0
for fragment_idx in range(self.num_fragments):
print(f"=======> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
logger.info(f"=======> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
kv_start = 0
kv_end = kv_start + self.num_frame_per_block * self.frame_seq_length
if fragment_idx > 0:
print("recompute the kv_cache ...")
logger.info("recompute the kv_cache ...")
with ProfilingContext4Debug("step_pre"):
self.model.scheduler.latents = self.model.scheduler.last_sample
self.model.scheduler.step_pre(step_index=self.model.scheduler.infer_steps - 1)
......@@ -115,12 +116,12 @@ class WanCausalRunner(WanRunner):
infer_blocks = self.infer_blocks - (fragment_idx > 0)
for block_idx in range(infer_blocks):
print(f"=======> block_idx: {block_idx + 1} / {infer_blocks}")
print(f"=======> kv_start: {kv_start}, kv_end: {kv_end}")
logger.info(f"=======> block_idx: {block_idx + 1} / {infer_blocks}")
logger.info(f"=======> kv_start: {kv_start}, kv_end: {kv_end}")
self.model.scheduler.reset()
for step_index in range(self.model.scheduler.infer_steps):
print(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
with ProfilingContext4Debug("step_pre"):
self.model.scheduler.step_pre(step_index=step_index)
......
......@@ -14,6 +14,7 @@ from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
import torch.distributed as dist
from loguru import logger
@RUNNER_REGISTER("wan2.1")
......@@ -47,7 +48,7 @@ class WanRunner(DefaultRunner):
lora_wrapper = WanLoraWrapper(model)
lora_name = lora_wrapper.load_lora(self.config.lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model)
print(f"Loaded LoRA: {lora_name}")
logger.info(f"Loaded LoRA: {lora_name}")
vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
if self.config.task == "i2v":
......
......@@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from einops import rearrange
from loguru import logger
__all__ = [
"WanVAE",
......@@ -801,7 +802,7 @@ class WanVAE:
split_dim = 2
images = self.decode_dist(zs, world_size, cur_rank, split_dim)
else:
print("Fall back to naive decode mode")
logger.info("Fall back to naive decode mode")
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
else:
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
......
......@@ -2,6 +2,7 @@ import time
import torch
from contextlib import ContextDecorator
from lightx2v.utils.envs import *
from loguru import logger
class _ProfilingContext(ContextDecorator):
......@@ -16,7 +17,7 @@ class _ProfilingContext(ContextDecorator):
def __exit__(self, exc_type, exc_val, exc_tb):
torch.cuda.synchronize()
elapsed = time.perf_counter() - self.start_time
print(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
return False
......