Commit 56af41eb authored by helloyongyang's avatar helloyongyang

Big Refactor

parent 142f6872
......@@ -27,7 +27,7 @@ RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1
RUN pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple \
&& pip install packaging ninja vllm torch torchvision diffusers transformers \
tokenizers accelerate safetensors opencv-python numpy imageio imageio-ffmpeg \
einops loguru sgl-kernel qtorch ftfy
einops loguru sgl-kernel qtorch ftfy easydict
# install flash-attention 2
RUN cd lightx2v/3rd/flash-attention && pip install --no-cache-dir -v -e .
......
{
"infer_steps": 20,
"target_video_length": 33,
"target_height": 720,
"target_width": 1280,
"attention_type": "flash_attn3",
"seed": 42,
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
"weight_auto_quant": true
},
"feature_caching": "TaylorSeer"
}
{
"infer_steps": 20,
"target_video_length": 33,
"target_height": 720,
"target_width": 1280,
"attention_type": "flash_attn3",
"seed": 42,
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
"weight_auto_quant": true
},
"parallel_attn_type": "ring"
}
{
"infer_steps": 20,
"target_video_length": 33,
"target_height": 720,
"target_width": 1280,
"attention_type": "flash_attn3",
"seed": 42,
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
"weight_auto_quant": true
},
"parallel_attn_type": "ulysses"
}
{
"infer_steps": 20,
"target_video_length": 33,
"i2v_resolution": "720p",
"attention_type": "flash_attn3",
"seed": 0,
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
"weight_auto_quant": true
}
}
{
"infer_steps": 20,
"target_video_length": 33,
"target_height": 720,
"target_width": 1280,
"attention_type": "flash_attn3",
"seed": 42,
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
"weight_auto_quant": true
}
}
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
"weight_auto_quant": true
}
}
{
"infer_steps": 50,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
"weight_auto_quant": true
}
}
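For reference, a minimal sketch of how a config file like the ones above can be loaded. This is an assumption about what set_config does with --config_json, not its actual implementation, and the file name is hypothetical; easydict (newly added to the Docker image above) gives attribute-style access to the parsed JSON.

import json
from easydict import EasyDict  # newly added to the Docker image above

# Hypothetical path; any of the JSON configs above (passed via --config_json) parses the same way.
with open("config.json") as f:
    config = EasyDict(json.load(f))

# Nested keys become attributes, e.g. config.infer_steps and config.mm_config.mm_type.
print(config.infer_steps, config.mm_config.mm_type)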
import argparse
import torch
import torch.distributed as dist
import os
import time
import gc
import json
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
from lightx2v.utils.envs import *
from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.utils import seed_all
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerTeaCaching
from lightx2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.common.ops import *
def load_models(config):
if config["parallel_attn_type"]:
cur_rank = dist.get_rank()  # get this process's rank
torch.cuda.set_device(cur_rank)  # bind this process to its CUDA device
image_encoder = None
if config.cpu_offload:
init_device = torch.device("cpu")
else:
init_device = torch.device("cuda")
if config.model_cls == "hunyuan":
if config.task == "t2v":
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(config.model_path, "text_encoder"), init_device)
else:
text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(config.model_path, "text_encoder_i2v"), init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(config.model_path, "text_encoder_2"), init_device)
text_encoders = [text_encoder_1, text_encoder_2]
model = HunyuanModel(config.model_path, config, init_device, config)
vae_model = VideoEncoderKLCausal3DModel(config.model_path, dtype=torch.float16, device=init_device, config=config)
elif config.model_cls == "wan2.1":
with ProfilingContext("Load Text Encoder"):
text_encoder = T5EncoderModel(
text_len=config["text_len"],
dtype=torch.bfloat16,
device=init_device,
checkpoint_path=os.path.join(config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
tokenizer_path=os.path.join(config.model_path, "google/umt5-xxl"),
shard_fn=None,
)
text_encoders = [text_encoder]
with ProfilingContext("Load Wan Model"):
model = WanModel(config.model_path, config, init_device)
if config.lora_path:
lora_wrapper = WanLoraWrapper(model)
with ProfilingContext("Load LoRA Model"):
lora_name = lora_wrapper.load_lora(config.lora_path)
lora_wrapper.apply_lora(lora_name, config.strength_model)
print(f"Loaded LoRA: {lora_name}")
with ProfilingContext("Load WAN VAE Model"):
vae_model = WanVAE(vae_pth=os.path.join(config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=config.parallel_vae)
if config.task == "i2v":
with ProfilingContext("Load Image Encoder"):
image_encoder = CLIPModel(
dtype=torch.float16,
device=init_device,
checkpoint_path=os.path.join(config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
tokenizer_path=os.path.join(config.model_path, "xlm-roberta-large"),
)
else:
raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
return model, text_encoders, vae_model, image_encoder
def set_target_shape(config, image_encoder_output):
if config.model_cls == "hunyuan":
if config.task == "t2v":
vae_scale_factor = 2 ** (4 - 1)
config.target_shape = (
1,
16,
(config.target_video_length - 1) // 4 + 1,
int(config.target_height) // vae_scale_factor,
int(config.target_width) // vae_scale_factor,
)
elif config.task == "i2v":
vae_scale_factor = 2 ** (4 - 1)
config.target_shape = (
1,
16,
(config.target_video_length - 1) // 4 + 1,
int(image_encoder_output["target_height"]) // vae_scale_factor,
int(image_encoder_output["target_width"]) // vae_scale_factor,
)
elif config.model_cls == "wan2.1":
if config.task == "i2v":
config.target_shape = (16, 21, config.lat_h, config.lat_w)
elif config.task == "t2v":
config.target_shape = (
16,
(config.target_video_length - 1) // 4 + 1,
int(config.target_height) // config.vae_stride[1],
int(config.target_width) // config.vae_stride[2],
)
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
num_patches = round((base_size / patch_size) ** 2)
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
aspect_ratio = float(height) / float(width)
diff_ratios = ratios - aspect_ratio
if aspect_ratio >= 1:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
else:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]
closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
closest_size = buckets[closest_ratio_id]
closest_ratio = ratios[closest_ratio_id]
return closest_size, closest_ratio
def run_image_encoder(config, image_encoder, vae_model):
if config.model_cls == "hunyuan":
img = Image.open(config.image_path).convert("RGB")
origin_size = img.size
i2v_resolution = "720p"
if i2v_resolution == "720p":
bucket_hw_base_size = 960
elif i2v_resolution == "540p":
bucket_hw_base_size = 720
elif i2v_resolution == "360p":
bucket_hw_base_size = 480
else:
raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
resize_param = min(closest_size)
center_crop_param = closest_size
ref_image_transform = torchvision.transforms.Compose(
[torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
)
semantic_image_pixel_values = [ref_image_transform(img)]
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))
img_latents = vae_model.encode(semantic_image_pixel_values, config).mode()
scaling_factor = 0.476986
img_latents.mul_(scaling_factor)
target_height, target_width = closest_size
return {"img": img, "img_latents": img_latents, "target_height": target_height, "target_width": target_width}
elif config.model_cls == "wan2.1":
img = Image.open(config.image_path).convert("RGB")
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
h, w = img.shape[1:]
aspect_ratio = h / w
max_area = config.target_height * config.target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
h = lat_h * config.vae_stride[1]
w = lat_w * config.vae_stride[2]
config.lat_h = lat_h
config.lat_w = lat_w
msk = torch.ones(1, 81, lat_h, lat_w, device=torch.device("cuda"))
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
vae_encode_out = vae_model.encode(
[torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], config
)[0]
vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
else:
raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
def run_text_encoder(text, text_encoders, config, image_encoder_output):
text_encoder_output = {}
if config.model_cls == "hunyuan":
for i, encoder in enumerate(text_encoders):
if config.task == "i2v" and i == 0:
text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], config)
else:
text_state, attention_mask = encoder.infer(text, config)
text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
elif config.model_cls == "wan2.1":
n_prompt = config.get("sample_neg_prompt", "")
context = text_encoders[0].infer([text], config)
context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
text_encoder_output["context"] = context
text_encoder_output["context_null"] = context_null
else:
raise NotImplementedError(f"Unsupported model type: {config.model_cls}")
return text_encoder_output
def init_scheduler(config, image_encoder_output):
if config.model_cls == "hunyuan":
if config.feature_caching == "NoCaching":
scheduler = HunyuanScheduler(config, image_encoder_output)
elif config.feature_caching == "Tea":
scheduler = HunyuanSchedulerTeaCaching(config, image_encoder_output)
elif config.feature_caching == "TaylorSeer":
scheduler = HunyuanSchedulerTaylorCaching(config, image_encoder_output)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")
elif config.model_cls == "wan2.1":
if config.feature_caching == "NoCaching":
scheduler = WanScheduler(config)
elif config.feature_caching == "Tea":
scheduler = WanSchedulerTeaCaching(config)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")
else:
raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
return scheduler
def run_vae(latents, generator, config):
images = vae_model.decode(latents, generator=generator, config=config)
return images
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan"], default="hunyuan")
......@@ -288,86 +24,23 @@ if __name__ == "__main__":
parser.add_argument("--image_path", type=str, default=None, help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--infer_steps", type=int, required=True)
parser.add_argument("--target_video_length", type=int, required=True)
parser.add_argument("--target_width", type=int, required=True)
parser.add_argument("--target_height", type=int, required=True)
parser.add_argument("--attention_type", type=str, required=True)
parser.add_argument("--sample_neg_prompt", type=str, default="")
parser.add_argument("--sample_guide_scale", type=float, default=5.0)
parser.add_argument("--sample_shift", type=float, default=5.0)
parser.add_argument("--do_mm_calib", action="store_true")
parser.add_argument("--cpu_offload", action="store_true")
parser.add_argument("--feature_caching", choices=["NoCaching", "TaylorSeer", "Tea"], default="NoCaching")
parser.add_argument("--mm_config", default=None)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--parallel_attn_type", default=None, choices=["ulysses", "ring"])
parser.add_argument("--parallel_vae", action="store_true")
parser.add_argument("--max_area", action="store_true")
parser.add_argument("--vae_stride", default=(4, 8, 8))
parser.add_argument("--patch_size", default=(1, 2, 2))
parser.add_argument("--teacache_thresh", type=float, default=0.26)
parser.add_argument("--use_ret_steps", action="store_true", default=False)
parser.add_argument("--use_bfloat16", action="store_true", default=True)
parser.add_argument("--lora_path", type=str, default=None)
parser.add_argument("--strength_model", type=float, default=1.0)
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--config_json", type=str, required=True)
args = parser.parse_args()
start_time = time.perf_counter()
print(f"args: {args}")
seed_all(args.seed)
config = set_config(args)
if config.parallel_attn_type:
dist.init_process_group(backend="nccl")
print(f"config: {config}")
with ProfilingContext("Load models"):
model, text_encoders, vae_model, image_encoder = load_models(config)
if config["task"] in ["i2v"]:
image_encoder_output = run_image_encoder(config, image_encoder, vae_model)
else:
image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}
with ProfilingContext("Run Text Encoder"):
text_encoder_output = run_text_encoder(config["prompt"], text_encoders, config, image_encoder_output)
with ProfilingContext("Total Cost"):
config = set_config(args)
print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
seed_all(config.seed)
set_target_shape(config, image_encoder_output)
scheduler = init_scheduler(config, image_encoder_output)
if config.parallel_attn_type:
dist.init_process_group(backend="nccl")
model.set_scheduler(scheduler)
gc.collect()
torch.cuda.empty_cache()
if CHECK_ENABLE_GRAPH_MODE():
default_runner = DefaultRunner(model, inputs)
runner = GraphRunner(default_runner)
else:
runner = DefaultRunner(model, inputs)
latents, generator = runner.run()
if config.cpu_offload:
scheduler.clear()
del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
torch.cuda.empty_cache()
with ProfilingContext("Run VAE"):
images = run_vae(latents, generator, config)
if not config.parallel_attn_type or (config.parallel_attn_type and dist.get_rank() == 0):
with ProfilingContext("Save video"):
if config.model_cls == "wan2.1":
cache_video(tensor=images, save_file=config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
else:
save_videos_grid(images, config.save_video_path, fps=24)
end_time = time.perf_counter()
print(f"Total cost: {end_time - start_time}")
if CHECK_ENABLE_GRAPH_MODE():
default_runner = RUNNER_REGISTER[config.model_cls](config)
runner = GraphRunner(default_runner)
else:
runner = RUNNER_REGISTER[config.model_cls](config)
runner.run_pipeline()
......@@ -102,7 +102,6 @@ class TextEncoderHFLlavaModel:
if config.cpu_offload:
self.to_cuda()
text = self.prompt_template.format(text)
print(f"text: {text}")
tokens = self.tokenizer(
text,
return_length=False,
......
......@@ -33,10 +33,10 @@ class WanPreInfer:
else:
context = inputs["text_encoder_output"]["context_null"]
seq_len = self.scheduler.seq_len
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
y = [inputs["image_encoder_output"]["vae_encode_out"]]
if self.task == "i2v":
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
y = [inputs["image_encoder_output"]["vae_encode_out"]]
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
......
from lightx2v.utils.profiler import ProfilingContext4Debug
import gc
import torch
import torch.distributed as dist
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
from lightx2v.utils.envs import *
class DefaultRunner:
def __init__(self, model, inputs):
self.model = model
self.inputs = inputs
def __init__(self, config):
self.config = config
self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model()
def run_input_encoder(self):
image_encoder_output = None
if self.config["task"] == "i2v":
with ProfilingContext("Run Img Encoder"):
image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
with ProfilingContext("Run Text Encoder"):
text_encoder_output = self.run_text_encoder(self.config["prompt"], self.text_encoders, self.config, image_encoder_output)
self.set_target_shape()
self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
gc.collect()
torch.cuda.empty_cache()
def run(self):
for step_index in range(self.model.scheduler.infer_steps):
......@@ -22,6 +40,37 @@ class DefaultRunner:
return self.model.scheduler.latents, self.model.scheduler.generator
def run_step(self, step_index=0):
self.init_scheduler()
self.run_input_encoder()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
self.model.scheduler.step_pre(step_index=step_index)
self.model.infer(self.inputs)
self.model.scheduler.step_post()
def end_run(self):
if self.config.cpu_offload:
self.model.scheduler.clear()
del self.inputs, self.model.scheduler, self.model, self.text_encoders
torch.cuda.empty_cache()
@ProfilingContext("Run VAE")
def run_vae(self, latents, generator):
images = self.vae_model.decode(latents, generator=generator, config=self.config)
return images
@ProfilingContext("Save video")
def save_video(self, images):
if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
if self.config.model_cls == "wan2.1":
cache_video(tensor=images, save_file=self.config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
else:
save_videos_grid(images, self.config.save_video_path, fps=24)
def run_pipeline(self):
self.init_scheduler()
self.run_input_encoder()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
latents, generator = self.run()
self.end_run()
images = self.run_vae(latents, generator)
self.save_video(images)
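For context, the refactor routes everything through these runner classes. The sketch below mirrors the __main__ block earlier in this commit and shows how a runner is selected and driven; RUNNER_REGISTER, CHECK_ENABLE_GRAPH_MODE, GraphRunner, and config are the objects already imported there.

# Minimal sketch of the refactored entry flow (mirrors the __main__ block above).
if CHECK_ENABLE_GRAPH_MODE():
    runner = GraphRunner(RUNNER_REGISTER[config.model_cls](config))
else:
    runner = RUNNER_REGISTER[config.model_cls](config)  # e.g. WanRunner for "wan2.1"
runner.run_pipeline()  # init_scheduler -> run_input_encoder -> prepare -> run -> end_run -> run_vae -> save_video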
import copy
from lightx2v.utils.profiler import ProfilingContext4Debug
......@@ -8,16 +7,10 @@ class GraphRunner:
self.compile()
def compile(self):
scheduler = copy.deepcopy(self.runner.model.scheduler)
inputs = copy.deepcopy(self.runner.inputs)
print("start compile...")
with ProfilingContext4Debug("compile"):
self.runner.run_step()
print("end compile...")
self.runner.model.set_scheduler(scheduler)
setattr(self.runner, "inputs", inputs)
def run(self):
return self.runner.run()
def run_pipeline(self):
return self.runner.run_pipeline()
import os
import numpy as np
import torch
import torchvision
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
import torch.distributed as dist
from lightx2v.utils.profiler import ProfilingContext
@RUNNER_REGISTER("hunyuan")
class HunyuanRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
@ProfilingContext("Load models")
def load_model(self):
if self.config["parallel_attn_type"]:
cur_rank = dist.get_rank()
torch.cuda.set_device(cur_rank)
image_encoder = None
if self.config.cpu_offload:
init_device = torch.device("cpu")
else:
init_device = torch.device("cuda")
if self.config.task == "t2v":
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(self.config.model_path, "text_encoder"), init_device)
else:
text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(self.config.model_path, "text_encoder_i2v"), init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(self.config.model_path, "text_encoder_2"), init_device)
text_encoders = [text_encoder_1, text_encoder_2]
model = HunyuanModel(self.config.model_path, self.config, init_device, self.config)
vae_model = VideoEncoderKLCausal3DModel(self.config.model_path, dtype=torch.float16, device=init_device, config=self.config)
return model, text_encoders, vae_model, image_encoder
def init_scheduler(self):
if self.config.feature_caching == "NoCaching":
scheduler = HunyuanScheduler(self.config)
elif self.config.feature_caching == "Tea":
scheduler = HunyuanSchedulerTeaCaching(self.config)
elif self.config.feature_caching == "TaylorSeer":
scheduler = HunyuanSchedulerTaylorCaching(self.config)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
self.model.set_scheduler(scheduler)
def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
text_encoder_output = {}
for i, encoder in enumerate(text_encoders):
if config.task == "i2v" and i == 0:
text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], config)
else:
text_state, attention_mask = encoder.infer(text, config)
text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
return text_encoder_output
def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
aspect_ratio = float(height) / float(width)
diff_ratios = ratios - aspect_ratio
if aspect_ratio >= 1:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
else:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]
closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
closest_size = buckets[closest_ratio_id]
closest_ratio = ratios[closest_ratio_id]
return closest_size, closest_ratio
def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
num_patches = round((base_size / patch_size) ** 2)
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
def run_image_encoder(self, config, image_encoder, vae_model):
img = Image.open(config.image_path).convert("RGB")
if config.i2v_resolution == "720p":
bucket_hw_base_size = 960
elif config.i2v_resolution == "540p":
bucket_hw_base_size = 720
elif config.i2v_resolution == "360p":
bucket_hw_base_size = 480
else:
raise ValueError(f"config.i2v_resolution: {config.i2v_resolution} must be in [360p, 540p, 720p]")
origin_size = img.size
crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
config.target_height, config.target_width = closest_size
resize_param = min(closest_size)
center_crop_param = closest_size
ref_image_transform = torchvision.transforms.Compose(
[torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
)
semantic_image_pixel_values = [ref_image_transform(img)]
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))
img_latents = vae_model.encode(semantic_image_pixel_values, config).mode()
scaling_factor = 0.476986
img_latents.mul_(scaling_factor)
return {"img": img, "img_latents": img_latents}
def set_target_shape(self):
vae_scale_factor = 2 ** (4 - 1)
self.config.target_shape = (
1,
16,
(self.config.target_video_length - 1) // 4 + 1,
int(self.config.target_height) // vae_scale_factor,
int(self.config.target_width) // vae_scale_factor,
)
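As a sanity check on HunyuanRunner.set_target_shape, here is the arithmetic for the 720x1280, 33-frame settings used in the hunyuan configs earlier in this commit.

# Worked example of HunyuanRunner.set_target_shape for the 720x1280, 33-frame configs above.
target_video_length, target_height, target_width = 33, 720, 1280
vae_scale_factor = 2 ** (4 - 1)              # 8: spatial downsampling of the causal 3D VAE
target_shape = (
    1,
    16,                                      # latent channels
    (target_video_length - 1) // 4 + 1,      # 9 latent frames (temporal stride 4)
    int(target_height) // vae_scale_factor,  # 90
    int(target_width) // vae_scale_factor,   # 160
)
assert target_shape == (1, 16, 9, 90, 160)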
import os
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerTeaCaching
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
import torch.distributed as dist
@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
@ProfilingContext("Load models")
def load_model(self):
if self.config["parallel_attn_type"]:
cur_rank = dist.get_rank()
torch.cuda.set_device(cur_rank)
image_encoder = None
if self.config.cpu_offload:
init_device = torch.device("cpu")
else:
init_device = torch.device("cuda")
text_encoder = T5EncoderModel(
text_len=self.config["text_len"],
dtype=torch.bfloat16,
device=init_device,
checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
shard_fn=None,
)
text_encoders = [text_encoder]
model = WanModel(self.config.model_path, self.config, init_device)
if self.config.lora_path:
lora_wrapper = WanLoraWrapper(model)
lora_name = lora_wrapper.load_lora(self.config.lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model)
print(f"Loaded LoRA: {lora_name}")
vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
if self.config.task == "i2v":
image_encoder = CLIPModel(
dtype=torch.float16,
device=init_device,
checkpoint_path=os.path.join(self.config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
)
return model, text_encoders, vae_model, image_encoder
def init_scheduler(self):
if self.config.feature_caching == "NoCaching":
scheduler = WanScheduler(self.config)
elif self.config.feature_caching == "Tea":
scheduler = WanSchedulerTeaCaching(self.config)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
self.model.set_scheduler(scheduler)
def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
text_encoder_output = {}
n_prompt = config.get("negative_prompt", "")
context = text_encoders[0].infer([text], config)
context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
text_encoder_output["context"] = context
text_encoder_output["context_null"] = context_null
return text_encoder_output
def run_image_encoder(self, config, image_encoder, vae_model):
img = Image.open(config.image_path).convert("RGB")
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
h, w = img.shape[1:]
aspect_ratio = h / w
max_area = config.target_height * config.target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
h = lat_h * config.vae_stride[1]
w = lat_w * config.vae_stride[2]
config.lat_h = lat_h
config.lat_w = lat_w
msk = torch.ones(1, 81, lat_h, lat_w, device=torch.device("cuda"))
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
vae_encode_out = vae_model.encode(
[torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], config
)[0]
vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
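The mask construction above is easy to misread, so the sketch below just traces the tensor shapes. It assumes the 480x832 wan2.1 config shown earlier and an input image with the same aspect ratio, so that lat_h = 60 and lat_w = 104.

import torch

# Shape walk-through of the i2v mask built in run_image_encoder above.
# Assumes lat_h=60, lat_w=104 (480x832 config with a matching input aspect ratio).
lat_h, lat_w = 60, 104
msk = torch.ones(1, 81, lat_h, lat_w)
msk[:, 1:] = 0                                             # only the conditioning frame is kept
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
assert msk.shape == (1, 84, lat_h, lat_w)                  # first frame repeated 4x
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)      # group frames by temporal stride 4
msk = msk.transpose(1, 2)[0]
assert msk.shape == (4, 21, lat_h, lat_w)                  # concatenated with the VAE latents below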
def set_target_shape(self):
if self.config.task == "i2v":
self.config.target_shape = (16, 21, self.config.lat_h, self.config.lat_w)
elif self.config.task == "t2v":
self.config.target_shape = (
16,
(self.config.target_video_length - 1) // 4 + 1,
int(self.config.target_height) // self.config.vae_stride[1],
int(self.config.target_width) // self.config.vae_stride[2],
)
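And the corresponding arithmetic for WanRunner.set_target_shape, using the 480x832, 81-frame settings from the wan2.1 configs above (vae_stride defaults to (4, 8, 8)).

# Worked example of WanRunner.set_target_shape for the 480x832, 81-frame t2v configs above.
target_video_length, target_height, target_width = 81, 480, 832
vae_stride = (4, 8, 8)
t2v_shape = (
    16,                                   # latent channels
    (target_video_length - 1) // 4 + 1,   # 21 latent frames
    int(target_height) // vae_stride[1],  # 60
    int(target_width) // vae_stride[2],   # 104
)
assert t2v_shape == (16, 21, 60, 104)
# The i2v branch hard-codes 21 latent frames, consistent with the 81-frame default:
# target_shape = (16, 21, lat_h, lat_w)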
......@@ -4,8 +4,8 @@ import torch
class HunyuanSchedulerTeaCaching(HunyuanScheduler):
def __init__(self, args, image_encoder_output):
super().__init__(args, image_encoder_output)
def __init__(self, config):
super().__init__(config)
self.cnt = 0
self.num_steps = self.args.infer_steps
self.teacache_thresh = self.args.teacache_thresh
......@@ -26,8 +26,8 @@ class HunyuanSchedulerTeaCaching(HunyuanScheduler):
class HunyuanSchedulerTaylorCaching(HunyuanScheduler):
def __init__(self, args, image_encoder_output):
super().__init__(args, image_encoder_output)
def __init__(self, config):
super().__init__(config)
self.cache_dic, self.current = cache_init(self.infer_steps)
def step_pre(self, step_index):
......
......@@ -235,29 +235,27 @@ def get_1d_rotary_pos_embed_riflex(
class HunyuanScheduler(BaseScheduler):
def __init__(self, args, image_encoder_output):
super().__init__(args)
self.infer_steps = self.args.infer_steps
self.image_encoder_output = image_encoder_output
def __init__(self, config):
super().__init__(config)
self.infer_steps = self.config.infer_steps
self.shift = 7.0
self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
assert len(self.timesteps) == self.infer_steps
self.embedded_guidance_scale = 6.0
self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [self.args.seed]]
self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [self.config.seed]]
self.noise_pred = None
self.prepare_latents(shape=self.args.target_shape, dtype=torch.float16)
def prepare(self, image_encoder_output):
self.image_encoder_output = image_encoder_output
self.prepare_latents(shape=self.config.target_shape, dtype=torch.float16, image_encoder_output=image_encoder_output)
self.prepare_guidance()
if self.args.task == "t2v":
target_height, target_width = self.args.target_height, self.args.target_width
else:
target_height, target_width = self.image_encoder_output["target_height"], self.image_encoder_output["target_width"]
self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=target_height, width=target_width)
self.prepare_rotary_pos_embedding(video_length=self.config.target_video_length, height=self.config.target_height, width=self.config.target_width)
def prepare_guidance(self):
self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0
def step_post(self):
if self.args.task == "t2v":
if self.config.task == "t2v":
sample = self.latents.to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
self.latents = sample + self.noise_pred.to(torch.float32) * dt
......@@ -267,16 +265,16 @@ class HunyuanScheduler(BaseScheduler):
latents = sample + self.noise_pred[:, :, 1:, :, :].to(torch.float32) * dt
self.latents = torch.concat([self.image_encoder_output["img_latents"], latents], dim=2)
def prepare_latents(self, shape, dtype):
if self.args.task == "t2v":
def prepare_latents(self, shape, dtype, image_encoder_output):
if self.config.task == "t2v":
self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
else:
x1 = self.image_encoder_output["img_latents"].repeat(1, 1, (self.args.target_video_length - 1) // 4 + 1, 1, 1)
x1 = image_encoder_output["img_latents"].repeat(1, 1, (self.config.target_video_length - 1) // 4 + 1, 1, 1)
x0 = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
t = torch.tensor([0.999]).to(device=torch.device("cuda"))
self.latents = x0 * t + x1 * (1 - t)
self.latents = self.latents.to(dtype=dtype)
self.latents = torch.concat([self.image_encoder_output["img_latents"], self.latents[:, :, 1:, :, :]], dim=2)
self.latents = torch.concat([image_encoder_output["img_latents"], self.latents[:, :, 1:, :, :]], dim=2)
def prepare_rotary_pos_embedding(self, video_length, height, width):
target_ndim = 3
......@@ -305,7 +303,7 @@ class HunyuanScheduler(BaseScheduler):
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
if self.args.task == "t2v":
if self.config.task == "t2v":
head_dim = hidden_size // heads_num
rope_dim_list = rope_dim_list
if rope_dim_list is None:
......
......@@ -2,8 +2,8 @@ import torch
class BaseScheduler:
def __init__(self, args):
self.args = args
def __init__(self, config):
self.config = config
self.step_index = 0
self.latents = None
......