"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "66301e124f19099ceef3023494551917fb67da83"
Commit 6347b21b authored by root

Support converting weights to the Diffusers format.

parent aec90a0d
@@ -7,6 +7,8 @@
     "seed": 42,
     "sample_guide_scale": 5,
     "sample_shift": 5,
+    "enable_cfg": true,
+    "cpu_offload": false,
     "mm_config": {
         "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
         "weight_auto_quant": true
...
import os
import re
import glob
import json
import argparse
import torch
from safetensors import safe_open, torch as st
from loguru import logger
from tqdm import tqdm
def get_key_mapping_rules(direction, model_type):
if model_type == "wan":
unified_rules = [
{"forward": (r"^head\.head$", "proj_out"), "backward": (r"^proj_out$", "head.head")},
{"forward": (r"^head\.modulation$", "scale_shift_table"), "backward": (r"^scale_shift_table$", "head.modulation")},
{"forward": (r"^text_embedding\.0\.", "condition_embedder.text_embedder.linear_1."), "backward": (r"^condition_embedder.text_embedder.linear_1\.", "text_embedding.0.")},
{"forward": (r"^text_embedding\.2\.", "condition_embedder.text_embedder.linear_2."), "backward": (r"^condition_embedder.text_embedder.linear_2\.", "text_embedding.2.")},
{"forward": (r"^time_embedding\.0\.", "condition_embedder.time_embedder.linear_1."), "backward": (r"^condition_embedder.time_embedder.linear_1\.", "time_embedding.0.")},
{"forward": (r"^time_embedding\.2\.", "condition_embedder.time_embedder.linear_2."), "backward": (r"^condition_embedder.time_embedder.linear_2\.", "time_embedding.2.")},
{"forward": (r"^time_projection\.1\.", "condition_embedder.time_proj."), "backward": (r"^condition_embedder.time_proj\.", "time_projection.1.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.q\.", r"blocks.\1.attn1.to_q."), "backward": (r"blocks\.(\d+)\.attn1\.to_q\.", r"blocks.\1.self_attn.q.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.k\.", r"blocks.\1.attn1.to_k."), "backward": (r"blocks\.(\d+)\.attn1\.to_k\.", r"blocks.\1.self_attn.k.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.v\.", r"blocks.\1.attn1.to_v."), "backward": (r"blocks\.(\d+)\.attn1\.to_v\.", r"blocks.\1.self_attn.v.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.o\.", r"blocks.\1.attn1.to_out.0."), "backward": (r"blocks\.(\d+)\.attn1\.to_out\.0\.", r"blocks.\1.self_attn.o.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.q\.", r"blocks.\1.attn2.to_q."), "backward": (r"blocks\.(\d+)\.attn2\.to_q\.", r"blocks.\1.cross_attn.q.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.k\.", r"blocks.\1.attn2.to_k."), "backward": (r"blocks\.(\d+)\.attn2\.to_k\.", r"blocks.\1.cross_attn.k.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.v\.", r"blocks.\1.attn2.to_v."), "backward": (r"blocks\.(\d+)\.attn2\.to_v\.", r"blocks.\1.cross_attn.v.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.o\.", r"blocks.\1.attn2.to_out.0."), "backward": (r"blocks\.(\d+)\.attn2\.to_out\.0\.", r"blocks.\1.cross_attn.o.")},
{"forward": (r"blocks\.(\d+)\.norm3\.", r"blocks.\1.norm2."), "backward": (r"blocks\.(\d+)\.norm2\.", r"blocks.\1.norm3.")},
{"forward": (r"blocks\.(\d+)\.ffn\.0\.", r"blocks.\1.ffn.net.0.proj."), "backward": (r"blocks\.(\d+)\.ffn\.net\.0\.proj\.", r"blocks.\1.ffn.0.")},
{"forward": (r"blocks\.(\d+)\.ffn\.2\.", r"blocks.\1.ffn.net.2."), "backward": (r"blocks\.(\d+)\.ffn\.net\.2\.", r"blocks.\1.ffn.2.")},
{"forward": (r"blocks\.(\d+)\.modulation\.", r"blocks.\1.scale_shift_table."), "backward": (r"blocks\.(\d+)\.scale_shift_table(?=\.|$)", r"blocks.\1.modulation")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.k_img\.", r"blocks.\1.attn2.add_k_proj."), "backward": (r"blocks\.(\d+)\.attn2\.add_k_proj\.", r"blocks.\1.cross_attn.k_img.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.v_img\.", r"blocks.\1.attn2.add_v_proj."), "backward": (r"blocks\.(\d+)\.attn2\.add_v_proj\.", r"blocks.\1.cross_attn.v_img.")},
{
"forward": (r"blocks\.(\d+)\.cross_attn\.norm_k_img\.weight", r"blocks.\1.attn2.norm_added_k.weight"),
"backward": (r"blocks\.(\d+)\.attn2\.norm_added_k\.weight", r"blocks.\1.cross_attn.norm_k_img.weight"),
},
{"forward": (r"img_emb\.proj\.0\.", r"condition_embedder.image_embedder.norm1."), "backward": (r"condition_embedder\.image_embedder\.norm1\.", r"img_emb.proj.0.")},
{"forward": (r"img_emb\.proj\.1\.", r"condition_embedder.image_embedder.ff.net.0.proj."), "backward": (r"condition_embedder\.image_embedder\.ff\.net\.0\.proj\.", r"img_emb.proj.1.")},
{"forward": (r"img_emb\.proj\.3\.", r"condition_embedder.image_embedder.ff.net.2."), "backward": (r"condition_embedder\.image_embedder\.ff\.net\.2\.", r"img_emb.proj.3.")},
{"forward": (r"img_emb\.proj\.4\.", r"condition_embedder.image_embedder.norm2."), "backward": (r"condition_embedder\.image_embedder\.norm2\.", r"img_emb.proj.4.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.norm_q\.weight", r"blocks.\1.attn1.norm_q.weight"), "backward": (r"blocks\.(\d+)\.attn1\.norm_q\.weight", r"blocks.\1.self_attn.norm_q.weight")},
{"forward": (r"blocks\.(\d+)\.self_attn\.norm_k\.weight", r"blocks.\1.attn1.norm_k.weight"), "backward": (r"blocks\.(\d+)\.attn1\.norm_k\.weight", r"blocks.\1.self_attn.norm_k.weight")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.norm_q\.weight", r"blocks.\1.attn2.norm_q.weight"), "backward": (r"blocks\.(\d+)\.attn2\.norm_q\.weight", r"blocks.\1.cross_attn.norm_q.weight")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.norm_k\.weight", r"blocks.\1.attn2.norm_k.weight"), "backward": (r"blocks\.(\d+)\.attn2\.norm_k\.weight", r"blocks.\1.cross_attn.norm_k.weight")},
# head projection mapping
{"forward": (r"^head\.head\.", "proj_out."), "backward": (r"^proj_out\.", "head.head.")},
]
if direction == "forward":
return [rule["forward"] for rule in unified_rules]
elif direction == "backward":
return [rule["backward"] for rule in unified_rules]
else:
raise ValueError(f"Invalid direction: {direction}")
else:
raise ValueError(f"Unsupported model type: {model_type}")
def convert_weights(args):
if os.path.isdir(args.source):
src_files = glob.glob(os.path.join(args.source, "*.safetensors"), recursive=True)
    elif args.source.endswith((".pt", ".pth", ".safetensors")):
src_files = [args.source]
else:
raise ValueError("Invalid input path")
merged_weights = {}
logger.info(f"Processing source files: {src_files}")
for file_path in tqdm(src_files, desc="Loading weights"):
logger.info(f"Loading weights from: {file_path}")
if file_path.endswith(".pt") or file_path.endswith(".pth"):
weights = torch.load(file_path, map_location="cpu", weights_only=True)
elif file_path.endswith(".safetensors"):
with safe_open(file_path, framework="pt") as f:
weights = {k: f.get_tensor(k) for k in f.keys()}
duplicate_keys = set(weights.keys()) & set(merged_weights.keys())
if duplicate_keys:
raise ValueError(f"Duplicate keys found: {duplicate_keys} in file {file_path}")
merged_weights.update(weights)
rules = get_key_mapping_rules(args.direction, args.model_type)
converted_weights = {}
logger.info("Converting keys...")
for key in tqdm(merged_weights.keys(), desc="Converting keys"):
new_key = key
for pattern, replacement in rules:
new_key = re.sub(pattern, replacement, new_key)
converted_weights[new_key] = merged_weights[key]
os.makedirs(args.output, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(args.source))[0] if args.source.endswith((".pt", ".pth", ".safetensors")) else "converted_model"
index = {"metadata": {"total_size": 0}, "weight_map": {}}
chunk_idx = 0
current_chunk = {}
for idx, (k, v) in tqdm(enumerate(converted_weights.items()), desc="Saving chunks"):
current_chunk[k] = v
        if args.chunk_size > 0 and (idx + 1) % args.chunk_size == 0:
output_filename = f"{base_name}_part{chunk_idx}.safetensors"
output_path = os.path.join(args.output, output_filename)
logger.info(f"Saving chunk to: {output_path}")
st.save_file(current_chunk, output_path)
for key in current_chunk:
index["weight_map"][key] = output_filename
index["metadata"]["total_size"] += os.path.getsize(output_path)
current_chunk = {}
chunk_idx += 1
if current_chunk:
output_filename = f"{base_name}_part{chunk_idx}.safetensors"
output_path = os.path.join(args.output, output_filename)
logger.info(f"Saving final chunk to: {output_path}")
st.save_file(current_chunk, output_path)
for key in current_chunk:
index["weight_map"][key] = output_filename
index["metadata"]["total_size"] += os.path.getsize(output_path)
# Save index file
index_path = os.path.join(args.output, "diffusion_pytorch_model.safetensors.index.json")
with open(index_path, "w", encoding="utf-8") as f:
json.dump(index, f, indent=2)
logger.info(f"Index file written to: {index_path}")
def main():
parser = argparse.ArgumentParser(description="Model weight format converter")
parser.add_argument("-s", "--source", required=True, help="Input path (file or directory)")
parser.add_argument("-o", "--output", required=True, help="Output directory path")
parser.add_argument("-d", "--direction", choices=["forward", "backward"], default="forward", help="Conversion direction: forward = 'lightx2v' -> 'Diffusers', backward = reverse")
parser.add_argument("-c", "--chunk-size", type=int, default=100, help="Chunk size for saving (only applies to forward), 0 = no chunking")
parser.add_argument("-t", "--model_type", choices=["wan"], default="wan", help="Model type")
args = parser.parse_args()
if os.path.isfile(args.output):
raise ValueError("Output path must be a directory, not a file")
logger.info("Starting model weight conversion...")
convert_weights(args)
logger.info(f"Conversion completed! Files saved to: {args.output}")
if __name__ == "__main__":
main()
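Not part of the commit: a minimal sketch of how the converter above could be driven programmatically. It assumes the script is importable as `converter` (the actual file name is not shown in this diff) and uses placeholder paths; the `Namespace` fields mirror the CLI flags defined in `main()`.

import re
from argparse import Namespace

# Hypothetical module name; the script's path is not shown in this commit.
from converter import convert_weights, get_key_mapping_rules

# Inspect a single key mapping in the lightx2v -> Diffusers ("forward") direction.
rules = get_key_mapping_rules(direction="forward", model_type="wan")
key = "blocks.0.self_attn.q.weight"
for pattern, replacement in rules:
    key = re.sub(pattern, replacement, key)
print(key)  # expected: blocks.0.attn1.to_q.weight

# Convert a whole checkpoint, writing one .safetensors chunk per 100 tensors
# (chunk_size=0 disables chunking). Paths are placeholders.
args = Namespace(
    source="path/to/lightx2v_model.safetensors",
    output="path/to/diffusers_out",
    direction="forward",
    chunk_size=100,
    model_type="wan",
)
convert_weights(args)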
@@ -4,6 +4,7 @@ import psutil
 import argparse
 from fastapi import FastAPI, Request
 from pydantic import BaseModel
+from loguru import logger
 import uvicorn
 import json
 import asyncio
@@ -26,15 +27,15 @@ def kill_all_related_processes():
         try:
             child.kill()
         except Exception as e:
-            print(f"Failed to kill child process {child.pid}: {e}")
+            logger.info(f"Failed to kill child process {child.pid}: {e}")
     try:
         current_process.kill()
     except Exception as e:
-        print(f"Failed to kill main process: {e}")
+        logger.info(f"Failed to kill main process: {e}")
 def signal_handler(sig, frame):
-    print("\nReceived Ctrl+C, shutting down all related processes...")
+    logger.info("\nReceived Ctrl+C, shutting down all related processes...")
     kill_all_related_processes()
     sys.exit(0)
@@ -79,11 +80,11 @@ if __name__ == "__main__":
     parser.add_argument("--config_json", type=str, required=True)
     parser.add_argument("--port", type=int, default=8000)
     args = parser.parse_args()
-    print(f"args: {args}")
+    logger.info(f"args: {args}")
     with ProfilingContext("Init Server Cost"):
         config = set_config(args)
-        print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = init_runner(config)
     uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
 from typing import Optional
+from loguru import logger
 import torch
 import torch.distributed as dist
@@ -21,7 +22,7 @@ class RingComm:
     def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
         if recv_tensor is None:
             res = torch.empty_like(to_send)
-            # print(f"send_recv: empty_like {to_send.shape}")
+            # logger.info(f"send_recv: empty_like {to_send.shape}")
         else:
             res = recv_tensor
...
@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 from lightx2v.attentions import attention
 from lightx2v.utils.utils import seed_all
+from loguru import logger
 seed_all(42)
@@ -65,10 +66,10 @@ def test_part_head():
     # Verify result consistency
     if cur_rank == 0:
         # import pdb; pdb.set_trace()
-        print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
+        logger.info(f"Outputs match: {torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3)}")
     # # Verify result consistency
-    # print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
+    # logger.info("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
 if __name__ == "__main__":
...
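Note that loguru's `logger.info` does not join multiple positional arguments the way `print` does; extra arguments are treated as `str.format` parameters for `{}` placeholders in the message, which is why the line above is interpolated with an f-string. A small standalone illustration (not part of the commit):

from loguru import logger

match = True

# print joins its positional arguments with spaces:
print("Outputs match:", match)

# loguru expects the value to be interpolated into the message, either
# with an f-string or with a {} placeholder:
logger.info(f"Outputs match: {match}")
logger.info("Outputs match: {}", match)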
@@ -20,6 +20,7 @@ import os
 import tensorrt as trt
 from .common_runtime import *
+from loguru import logger
 try:
     # Sometimes python does not understand FileNotFoundError
@@ -67,11 +68,11 @@ def find_sample_data(description="Runs a TensorRT Python sample", subfolder="",
         data_path = os.path.join(data_dir, subfolder)
         if not os.path.exists(data_path):
             if data_dir != kDEFAULT_DATA_ROOT:
-                print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
+                logger.info("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
             data_path = data_dir
         # Make sure data directory exists.
         if not (os.path.exists(data_path)) and data_dir != kDEFAULT_DATA_ROOT:
-            print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
+            logger.info("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
         return data_path
     data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
...
@@ -4,6 +4,7 @@ from vllm import _custom_ops as ops
 import sgl_kernel
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
 from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
+from loguru import logger
 try:
     import q8_kernels.functional as Q8F
@@ -461,7 +462,7 @@ if __name__ == "__main__":
     mm_weight.load(weight_dict)
     input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
     output_tensor = mm_weight.apply(input_tensor)
-    print(output_tensor.shape)
+    logger.info(output_tensor.shape)
     weight_dict = {
         "xx.weight": torch.randn(8192, 4096),
@@ -473,7 +474,7 @@ if __name__ == "__main__":
     mm_weight.load(weight_dict)
     input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
     output_tensor = mm_weight.apply(input_tensor)
-    print(output_tensor.shape)
+    logger.info(output_tensor.shape)
     weight_dict = {
         "xx.weight": torch.randn(8192, 4096),
@@ -485,4 +486,4 @@ if __name__ == "__main__":
     mm_weight.load(weight_dict)
     input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
     output_tensor = mm_weight.apply(input_tensor)
-    print(output_tensor.shape)
+    logger.info(output_tensor.shape)
@@ -15,6 +15,7 @@ from lightx2v.models.runners.wan.wan_causal_runner import WanCausalRunner
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.common.ops import *
+from loguru import logger
 def init_runner(config):
@@ -37,17 +38,17 @@ if __name__ == "__main__":
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
+    parser.add_argument("--enable_cfg", action="store_true")
     parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--negative_prompt", type=str, default="")
     parser.add_argument("--image_path", type=str, default="", help="Path to the input image file or directory for the image-to-video (i2v) task")
     parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="Path to save the output video file")
     args = parser.parse_args()
-    print(f"args: {args}")
+    logger.info(f"args: {args}")
     with ProfilingContext("Total Cost"):
         config = set_config(args)
-        print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = init_runner(config)
         runner.run_pipeline()
 import torch
 from transformers import CLIPTextModel, AutoTokenizer
+from loguru import logger
 class TextEncoderHFClipModel:
@@ -54,4 +55,4 @@ if __name__ == "__main__":
     model = TextEncoderHFClipModel(model_path, torch.device("cuda"))
     text = "A cat walks on the grass, realistic style."
     outputs = model.infer(text)
-    print(outputs)
+    logger.info(outputs)
 import torch
 from transformers import AutoModel, AutoTokenizer
+from loguru import logger
 class TextEncoderHFLlamaModel:
@@ -67,4 +68,4 @@ if __name__ == "__main__":
     model = TextEncoderHFLlamaModel(model_path, torch.device("cuda"))
     text = "A cat walks on the grass, realistic style."
     outputs = model.infer(text)
-    print(outputs)
+    logger.info(outputs)
@@ -3,6 +3,7 @@ from PIL import Image
 import numpy as np
 import torchvision.transforms as transforms
 from transformers import LlavaForConditionalGeneration, CLIPImageProcessor, AutoTokenizer
+from loguru import logger
 def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
@@ -158,4 +159,4 @@ if __name__ == "__main__":
     img_path = "/mtc/yongyang/projects/lightx2v/assets/inputs/imgs/img_1.jpg"
     img = Image.open(img_path).convert("RGB")
     outputs = model.infer(text, img, None)
-    print(outputs)
+    logger.info(outputs)
@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from .tokenizer import HuggingfaceTokenizer
+from loguru import logger
 __all__ = [
     "T5Model",
@@ -522,4 +523,4 @@ if __name__ == "__main__":
     )
     text = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
     outputs = model.infer(text)
-    print(outputs)
+    logger.info(outputs)
@@ -10,6 +10,7 @@ import torchvision.transforms as T
 from lightx2v.attentions import attention
 from lightx2v.models.input_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
+from loguru import logger
 from .xlm_roberta import XLMRoberta
@@ -190,7 +191,7 @@ class VisionTransformer(nn.Module):
         norm_eps=1e-5,
     ):
         if image_size % patch_size != 0:
-            print("[WARNING] image_size is not divisible by patch_size", flush=True)
+            logger.info("[WARNING] image_size is not divisible by patch_size")
         assert pool_type in ("token", "token_fc", "attn_pool")
         out_dim = out_dim or dim
         super().__init__()
...
@@ -123,7 +123,9 @@ class WanModel:
         self.scheduler.cnt += 1
         if self.scheduler.cnt >= self.scheduler.num_steps:
             self.scheduler.cnt = 0
+        self.scheduler.noise_pred = noise_pred_cond
+        if self.config["enable_cfg"]:
             embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
             x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
             noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
...
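For context, the new `enable_cfg` flag gates the second, unconditional denoising pass used for classifier-free guidance; when it is off, the conditional prediction is used directly. A minimal standalone sketch of the usual combination step, assuming `sample_guide_scale` from the config plays the role of the guidance scale (the actual combination code is outside this hunk):

from typing import Optional

import torch


def combine_cfg(noise_pred_cond: torch.Tensor, noise_pred_uncond: Optional[torch.Tensor], enable_cfg: bool, guide_scale: float) -> torch.Tensor:
    """Blend conditional and unconditional predictions (classifier-free guidance)."""
    if not enable_cfg or noise_pred_uncond is None:
        # Without CFG the conditional prediction is used as-is, mirroring
        # `self.scheduler.noise_pred = noise_pred_cond` in the hunk above.
        return noise_pred_cond
    return noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)


# Example with guide_scale = 5, matching "sample_guide_scale": 5 in the sample config.
cond = torch.randn(1, 16, 8, 8)
uncond = torch.randn(1, 16, 8, 8)
out = combine_cfg(cond, uncond, enable_cfg=True, guide_scale=5.0)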
@@ -4,6 +4,7 @@ import torch.distributed as dist
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
 from lightx2v.utils.utils import save_videos_grid, cache_video
 from lightx2v.utils.envs import *
+from loguru import logger
 class DefaultRunner:
@@ -32,7 +33,7 @@ class DefaultRunner:
     def run(self):
         for step_index in range(self.model.scheduler.infer_steps):
-            print(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
+            logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
             with ProfilingContext4Debug("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)
...
 from lightx2v.utils.profiler import ProfilingContext4Debug
+from loguru import logger
 class GraphRunner:
@@ -7,10 +8,10 @@ class GraphRunner:
         self.compile()
     def compile(self):
-        print("start compile...")
+        logger.info("start compile...")
         with ProfilingContext4Debug("compile"):
             self.runner.run_step()
-        print("end compile...")
+        logger.info("end compile...")
     def run_pipeline(self):
         return self.runner.run_pipeline()
@@ -14,6 +14,7 @@ from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
 from lightx2v.models.networks.wan.causal_model import WanCausalModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
+from loguru import logger
 import torch.distributed as dist
@@ -54,7 +55,7 @@ class WanCausalRunner(WanRunner):
             lora_wrapper = WanLoraWrapper(model)
             lora_name = lora_wrapper.load_lora(self.config.lora_path)
             lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-            print(f"Loaded LoRA: {lora_name}")
+            logger.info(f"Loaded LoRA: {lora_name}")
         vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
         if self.config.task == "i2v":
@@ -95,13 +96,13 @@ class WanCausalRunner(WanRunner):
         start_block_idx = 0
         for fragment_idx in range(self.num_fragments):
-            print(f"=======> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
+            logger.info(f"=======> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
             kv_start = 0
             kv_end = kv_start + self.num_frame_per_block * self.frame_seq_length
             if fragment_idx > 0:
-                print("recompute the kv_cache ...")
+                logger.info("recompute the kv_cache ...")
                 with ProfilingContext4Debug("step_pre"):
                     self.model.scheduler.latents = self.model.scheduler.last_sample
                     self.model.scheduler.step_pre(step_index=self.model.scheduler.infer_steps - 1)
@@ -115,12 +116,12 @@ class WanCausalRunner(WanRunner):
             infer_blocks = self.infer_blocks - (fragment_idx > 0)
             for block_idx in range(infer_blocks):
-                print(f"=======> block_idx: {block_idx + 1} / {infer_blocks}")
-                print(f"=======> kv_start: {kv_start}, kv_end: {kv_end}")
+                logger.info(f"=======> block_idx: {block_idx + 1} / {infer_blocks}")
+                logger.info(f"=======> kv_start: {kv_start}, kv_end: {kv_end}")
                 self.model.scheduler.reset()
                 for step_index in range(self.model.scheduler.infer_steps):
-                    print(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
+                    logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
                     with ProfilingContext4Debug("step_pre"):
                         self.model.scheduler.step_pre(step_index=step_index)
...
@@ -14,6 +14,7 @@ from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 import torch.distributed as dist
+from loguru import logger
 @RUNNER_REGISTER("wan2.1")
@@ -47,7 +48,7 @@ class WanRunner(DefaultRunner):
             lora_wrapper = WanLoraWrapper(model)
             lora_name = lora_wrapper.load_lora(self.config.lora_path)
             lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-            print(f"Loaded LoRA: {lora_name}")
+            logger.info(f"Loaded LoRA: {lora_name}")
         vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
         if self.config.task == "i2v":
...
@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
 from einops import rearrange
+from loguru import logger
 __all__ = [
     "WanVAE",
@@ -801,7 +802,7 @@ class WanVAE:
                 split_dim = 2
                 images = self.decode_dist(zs, world_size, cur_rank, split_dim)
            else:
-                print("Fall back to naive decode mode")
+                logger.info("Fall back to naive decode mode")
                 images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
        else:
             images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
...
@@ -2,6 +2,7 @@ import time
 import torch
 from contextlib import ContextDecorator
 from lightx2v.utils.envs import *
+from loguru import logger
 class _ProfilingContext(ContextDecorator):
@@ -16,7 +17,7 @@ class _ProfilingContext(ContextDecorator):
     def __exit__(self, exc_type, exc_val, exc_tb):
         torch.cuda.synchronize()
         elapsed = time.perf_counter() - self.start_time
-        print(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
+        logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
         return False
...
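Because `_ProfilingContext` derives from `contextlib.ContextDecorator`, the same object can wrap a block with `with` (as in `ProfilingContext("Init Server Cost")` earlier in this commit) or decorate a function. A standalone sketch of that pattern with the CUDA synchronization omitted so it runs anywhere; the class and names below are illustrative, not from the repo:

import time
from contextlib import ContextDecorator

from loguru import logger


class SimpleProfiler(ContextDecorator):
    """Times a code block or a decorated function and logs the elapsed time."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.perf_counter() - self.start_time
        logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
        return False  # do not suppress exceptions


# Usable both as a context manager and as a decorator:
with SimpleProfiler("load weights"):
    time.sleep(0.1)


@SimpleProfiler("decorated step")
def step():
    time.sleep(0.05)


step()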