Commit 039456f2 authored by sandy, committed by GitHub

Merge pull request #161 from ModelTC/feat-audio

Update CLIP preprocessing
parents 048be946 7e6f9418
......@@ -14,5 +14,6 @@
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"adaptive_resize": true
"adaptive_resize": true,
"use_31_block": false
}
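The new use_31_block key is consumed in CLIPModel.visual below via getattr(args, "use_31_block", True), so configs that omit it keep the previously hard-coded behavior, while this config explicitly turns it off. A minimal sketch of the lookup semantics; SimpleNamespace stands in for the runner's parsed config object and is not part of the repository:

from types import SimpleNamespace

# Stand-in for the parsed JSON config above; only the attribute lookup matters here.
args = SimpleNamespace(adaptive_resize=True, use_31_block=False)

# A missing key falls back to True, matching the previously hard-coded
# out = self.model.visual(videos, use_31_block=True).
use_31_block = getattr(args, "use_31_block", True)
assert use_31_block is False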
......@@ -11,10 +11,6 @@ import torchvision.transforms as T
from lightx2v.attentions import attention
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
from einops import rearrange
from torch import Tensor
from transformers import CLIPVisionModel
__all__ = [
"XLMRobertaCLIP",
......@@ -448,14 +444,14 @@ class CLIPModel:
def visual(self, videos, args):
if hasattr(args, "cpu_offload") and args.cpu_offload:
self.to_cuda()
use_31_block = getattr(args, "use_31_block", True)
# preprocess
size = (self.model.image_size,) * 2
videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
videos = torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in videos])
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
with torch.amp.autocast("cuda", dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True)
out = self.model.visual(videos, use_31_block=use_31_block)
if hasattr(args, "cpu_offload") and args.cpu_offload:
self.to_cpu()
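With the transpose(0, 1) dropped from the preprocessing, each entry of videos is expected frames-first as an (N, C, H, W) tensor with values in [-1, 1], since visual rescales via mul_(0.5).add_(0.5) before applying the CLIP normalization transform. A hedged usage sketch under that convention; clip_model and config are placeholders for the CLIPModel instance and runner config, not names defined here:

import torch

# One conditioning frame in [-1, 1], already laid out as (1, C, H, W);
# the updated preprocessing resizes it spatially with F.interpolate and no transpose.
ref_frame = torch.rand(1, 3, 480, 832) * 2 - 1

# Placeholder call (clip_model / config come from the surrounding code):
# features = clip_model.visual([ref_frame.cuda()], config)
# features.squeeze(0) is what the runners below cast to torch.bfloat16.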
......@@ -466,51 +462,3 @@ class CLIPModel:
def to_cpu(self):
self.model = self.model.cpu()
class WanVideoIPHandler:
def __init__(self, model_name, repo_or_path, require_grad=False, mode="eval", device="cuda", dtype=torch.float16):
# image_processor = CLIPImageProcessor.from_pretrained(
# repo_or_path, subfolder='image_processor')
"""720P-I2V-diffusers config is
"size": {
"shortest_edge": 224
}
and 480P-I2V-diffusers config is
"size": {
"height": 224,
"width": 224
}
but Wan2.1 official use no_crop resize by default
so I don't use CLIPImageProcessor
"""
image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
logger.info(f"Using image encoder {model_name} from {repo_or_path}")
image_encoder.requires_grad_(require_grad)
if mode == "eval":
image_encoder.eval()
else:
image_encoder.train()
self.dtype = dtype
self.device = device
self.image_encoder = image_encoder.to(device=device, dtype=dtype)
self.size = (224, 224)
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
self.normalize = T.Normalize(mean=mean, std=std)
# self.image_processor = image_processor
def encode(
self,
img_tensor: Tensor,
):
if img_tensor.ndim == 5: # B C T H W
# img_tensor = img_tensor[:, :, 0]
img_tensor = rearrange(img_tensor, "B C 1 H W -> B C H W")
img_tensor = torch.clamp(img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
img_tensor = F.interpolate(img_tensor, size=self.size, mode="bicubic", align_corners=False)
img_tensor = self.normalize(img_tensor).to(self.dtype)
image_embeds = self.image_encoder(pixel_values=img_tensor, output_hidden_states=True)
return image_embeds.hidden_states[-1]
......@@ -10,29 +10,18 @@ from dataclasses import dataclass
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel, WanVideoIPHandler
from lightx2v.models.networks.wan.audio_model import WanAudioModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix, ConsistencyModelScheduler
from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
from loguru import logger
import torch.distributed as dist
from einops import rearrange
import torchaudio as ta
from transformers import AutoFeatureExtractor
from torchvision.datasets.folder import IMG_EXTENSIONS
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
......@@ -618,12 +607,6 @@ class WanAudioRunner(WanRunner):
return base_model
def load_image_encoder(self):
"""Load image encoder"""
clip_model_dir = self.config["model_path"] + "/image_encoder"
image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
return image_encoder
def run_image_encoder(self, config, vae_model):
"""Run image encoder"""
......@@ -638,7 +621,7 @@ class WanAudioRunner(WanRunner):
cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
config.tgt_h = tgt_h
config.tgt_w = tgt_w
clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
lat_h, lat_w = tgt_h // 8, tgt_w // 8
......@@ -662,7 +645,7 @@ class WanAudioRunner(WanRunner):
# Resize image to target size
cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
# Prepare for VAE encoding
cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
......
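Both resize branches now route the conditioning frame through the shared CLIPModel.visual (note the list wrapper and the config argument) instead of the removed WanVideoIPHandler.encode. After CLIP encoding, the frame gets a singleton temporal axis for VAE conditioning and the latent grid is the pixel size divided by 8, matching the // 8 above and the Wan VAE's 8x spatial downsampling. A small runnable check of those shape relations; the 480x832 resolution is only an example:

import torch
from einops import rearrange

cond_frms = torch.zeros(1, 3, 480, 832)                   # (1, C, H, W) after resizing
cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")  # add the temporal axis for VAE conditioning
lat_h, lat_w = 480 // 8, 832 // 8                         # latent grid, as in the // 8 above
assert cond_frms.shape == (1, 3, 1, 480, 832) and (lat_h, lat_w) == (60, 104)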
......@@ -197,7 +197,7 @@ class WanRunner(DefaultRunner):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.image_encoder = self.load_image_encoder()
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.image_encoder
torch.cuda.empty_cache()
......
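The indexing change follows from the preprocessing update above: TF.to_tensor(img) yields a (C, H, W) tensor, and since CLIPModel.visual no longer transposes dims 0 and 1, the singleton batch/frame axis must now come first. A minimal check of the two layouts; the 480x832 image size is illustrative:

import torch
import torchvision.transforms.functional as TF
from PIL import Image

img = TF.to_tensor(Image.new("RGB", (832, 480))).sub_(0.5).div_(0.5)  # (C, H, W) in [-1, 1]
old_layout = img[:, None, :, :]   # (C, 1, H, W) -- paired with the removed transpose(0, 1)
new_layout = img[None, :, :, :]   # (1, C, H, W) -- what the updated F.interpolate input expects
assert old_layout.shape == (3, 1, 480, 832)
assert new_layout.shape == (1, 3, 480, 832)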