Commit 2faef15a authored by gushiqiao, committed by GitHub

Merge pull request #30 from ModelTC/dev_quant

Support loading quant weights in safetensors format.
parents ca696d83 0c6736d3
@@ -12,7 +12,9 @@ from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
 from lightx2v.models.networks.wan.infer.transformer_infer import (
     WanTransformerInfer,
 )
-from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
+from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
+    WanTransformerInferTeaCaching,
+)
 from safetensors import safe_open
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
@@ -83,9 +85,32 @@ class WanModel:
     def _load_ckpt_quant_model(self):
         assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
-        logger.info(f"Loading quant model from {self.config.naive_quant_path}")
-        quant_weights_path = os.path.join(self.config.naive_quant_path, "quant_weights.pth")
-        weight_dict = torch.load(quant_weights_path, map_location=self.device, weights_only=True)
+        ckpt_path = self.config.naive_quant_path
+        logger.info(f"Loading quant model from {ckpt_path}")
+        quant_pth_file = os.path.join(ckpt_path, "quant_weights.pth")
+        if os.path.exists(quant_pth_file):
+            logger.info("Found quant_weights.pth, loading as PyTorch model.")
+            weight_dict = torch.load(quant_pth_file, map_location=self.device, weights_only=True)
+        else:
+            index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
+            if not index_files:
+                raise FileNotFoundError(f"No quant_weights.pth or *.index.json found in {ckpt_path}")
+            index_path = os.path.join(ckpt_path, index_files[0])
+            logger.info(f"quant_weights.pth not found. Using safetensors index: {index_path}")
+            with open(index_path, "r") as f:
+                index_data = json.load(f)
+            weight_dict = {}
+            for filename in set(index_data["weight_map"].values()):
+                safetensor_path = os.path.join(ckpt_path, filename)
+                logger.info(f"Loading weights from {safetensor_path}")
+                partial_weights = load_file(safetensor_path, device=self.device)
+                weight_dict.update(partial_weights)
         return weight_dict

     def _init_weights(self, weight_dict=None):
......
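Review note: the new fallback branch consumes the Hugging Face-style sharded-checkpoint layout, where a `*.index.json` file carries a `weight_map` from tensor names to the shard files containing them, and each shard is loaded once via `safetensors.torch.load_file`. The hunk as shown assumes `import json` and `from safetensors.torch import load_file` are in scope, presumably added elsewhere in this commit. A minimal self-contained sketch of the same loading scheme; the directory and shard names below are hypothetical:

```python
# Hypothetical layout produced by a quantization dump:
#   quant_ckpt/
#     quant_weights.index.json
#     quant_weights-00001-of-00002.safetensors
#     quant_weights-00002-of-00002.safetensors
#
# quant_weights.index.json (HF-style):
#   {"metadata": {"total_size": ...},
#    "weight_map": {"blocks.0.attn.q.weight": "quant_weights-00001-of-00002.safetensors", ...}}
import json
import os

from safetensors.torch import load_file


def load_sharded_safetensors(ckpt_path, device="cpu"):
    index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
    if not index_files:
        raise FileNotFoundError(f"No *.index.json found in {ckpt_path}")
    with open(os.path.join(ckpt_path, index_files[0]), "r") as f:
        weight_map = json.load(f)["weight_map"]
    weight_dict = {}
    # Load each shard exactly once, even though many tensors point to the same file.
    for filename in set(weight_map.values()):
        weight_dict.update(load_file(os.path.join(ckpt_path, filename), device=device))
    return weight_dict
```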
@@ -82,9 +82,9 @@ class DefaultRunner:
     def save_video(self, images):
         if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
             if self.config.model_cls in ["wan2.1", "wan2.1_causvid", "wan2.1_skyreels_v2_df"]:
-                cache_video(tensor=images, save_file=self.config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
+                cache_video(tensor=images, save_file=self.config.save_video_path, fps=self.config.get("fps", 16), nrow=1, normalize=True, value_range=(-1, 1))
             else:
-                save_videos_grid(images, self.config.save_video_path, fps=24)
+                save_videos_grid(images, self.config.save_video_path, fps=self.config.get("fps", 24))

     def run_pipeline(self):
         if self.config["use_prompt_enhancer"]:
......
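The `save_video` change makes the frame rate configurable through the config while keeping the old hard-coded values as defaults: 16 fps for the wan2.1-family models, 24 fps otherwise. A quick illustration of the `config.get` fallback, using plain dicts as stand-ins for the config object:

```python
# Without an "fps" key the old defaults still apply.
config = {"model_cls": "wan2.1"}
assert config.get("fps", 16) == 16

# An explicit "fps" key now overrides the default.
config = {"model_cls": "wan2.1", "fps": 24}
assert config.get("fps", 16) == 24
```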
@@ -6,7 +6,9 @@ from PIL import Image
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
-from lightx2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerTeaCaching
+from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
+    WanSchedulerTeaCaching,
+)
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
 from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
@@ -50,12 +52,19 @@ class WanRunner(DefaultRunner):
                 lora_wrapper.apply_lora(lora_name, self.config.strength_model)
                 logger.info(f"Loaded LoRA: {lora_name}")
-        vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
+        vae_model = WanVAE(
+            vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
+            device=init_device,
+            parallel=self.config.parallel_vae,
+        )
         if self.config.task == "i2v":
             image_encoder = CLIPModel(
                 dtype=torch.float16,
                 device=init_device,
-                checkpoint_path=os.path.join(self.config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
+                checkpoint_path=os.path.join(
+                    self.config.model_path,
+                    "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
+                ),
                 tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
             )
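These reformatted constructors still resolve every checkpoint relative to `self.config.model_path`: the VAE weights, the CLIP vision encoder, and the XLM-RoBERTa tokenizer directory. A small sketch (hypothetical helper, file names taken from the hunk above) that validates the layout these calls expect before loading:

```python
import os


def check_wan_assets(model_path, i2v=False):
    # Files/directories referenced by the WanVAE and CLIPModel constructors.
    expected = ["Wan2.1_VAE.pth"]
    if i2v:
        expected += [
            "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
            "xlm-roberta-large",  # tokenizer directory
        ]
    missing = [p for p in expected if not os.path.exists(os.path.join(model_path, p))]
    if missing:
        raise FileNotFoundError(f"Missing under {model_path}: {missing}")
```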
@@ -94,23 +103,38 @@ class WanRunner(DefaultRunner):
         config.lat_h = lat_h
         config.lat_w = lat_w
-        msk = torch.ones(1, 81, lat_h, lat_w, device=torch.device("cuda"))
+        msk = torch.ones(1, config.target_video_length, lat_h, lat_w, device=torch.device("cuda"))
         msk[:, 1:] = 0
         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
         msk = msk.transpose(1, 2)[0]
         vae_encode_out = vae_model.encode(
-            [torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], config
+            [
+                torch.concat(
+                    [
+                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
+                        torch.zeros(3, 80, h, w),
+                    ],
+                    dim=1,
+                ).cuda()
+            ],
+            config,
         )[0]
         vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
         return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}

     def set_target_shape(self):
+        num_channels_latents = self.config.get("num_channels_latents", 16)
         if self.config.task == "i2v":
-            self.config.target_shape = (16, 21, self.config.lat_h, self.config.lat_w)
+            self.config.target_shape = (
+                num_channels_latents,
+                (self.config.target_video_length - 1) // 4 + 1,
+                self.config.lat_h,
+                self.config.lat_w,
+            )
         elif self.config.task == "t2v":
             self.config.target_shape = (
-                16,
+                num_channels_latents,
                 (self.config.target_video_length - 1) // 4 + 1,
                 int(self.config.target_height) // self.config.vae_stride[1],
                 int(self.config.target_width) // self.config.vae_stride[2],
......
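The last hunk replaces the hard-coded 81 frames and 21 latent frames with values derived from `target_video_length` and `num_channels_latents`. The Wan VAE compresses time by a factor of 4 while keeping the first frame, so the latent frame count is `(target_video_length - 1) // 4 + 1` (21 for the default 81), and the i2v mask comes out with shape `(4, T_latent, lat_h, lat_w)`. Two caveats: the reshape only folds evenly when `target_video_length = 4k + 1`, and the `torch.zeros(3, 80, h, w)` padding in the encode call still hard-codes 80 (= 81 - 1) frames. A CPU-only shape check of the arithmetic; the latent size 60x104 is illustrative:

```python
import torch

target_video_length, lat_h, lat_w = 81, 60, 104

# As in set_target_shape: (81 - 1) // 4 + 1 == 21 latent frames.
t_latent = (target_video_length - 1) // 4 + 1
assert t_latent == 21

# i2v mask as in the hunk above: keep frame 0, zero the rest, repeat frame 0
# four times so the temporal length (4 + 80 = 84) folds evenly into groups of 4.
msk = torch.ones(1, target_video_length, lat_h, lat_w)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
assert msk.shape == (4, t_latent, lat_h, lat_w)
```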