Commit 039456f2 authored by sandy, committed by GitHub

Merge pull request #161 from ModelTC/feat-audio

Update CLIP preprocessing
parents 048be946 7e6f9418
@@ -14,5 +14,6 @@
     "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false,
-    "adaptive_resize": true
+    "adaptive_resize": true,
+    "use_31_block": false
 }
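The new `use_31_block` key is read in `CLIPModel.visual()` via `getattr(args, "use_31_block", True)` (see the hunk below), so configs that omit it keep the previous behaviour. A minimal sketch of that lookup; the `SimpleNamespace` stand-in for the runner config is an assumption for illustration only:

```python
# Sketch: how the new "use_31_block" flag reaches CLIPModel.visual() at runtime.
# The getattr(...) call and its default mirror the hunk below; the SimpleNamespace
# config object is a stand-in, not the repo's actual config class.
from types import SimpleNamespace

args = SimpleNamespace(cpu_offload=False, adaptive_resize=True, use_31_block=False)

# Configs that omit the key fall back to True, i.e. the previous behaviour.
use_31_block = getattr(args, "use_31_block", True)
print(use_31_block)  # False, then forwarded as self.model.visual(videos, use_31_block=use_31_block)
```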
@@ -11,10 +11,6 @@ import torchvision.transforms as T
 from lightx2v.attentions import attention
 from loguru import logger
 from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
-from einops import rearrange
-from torch import Tensor
-from transformers import CLIPVisionModel
-
 
 __all__ = [
     "XLMRobertaCLIP",
@@ -448,14 +444,14 @@ class CLIPModel:
     def visual(self, videos, args):
         if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cuda()
-
+        use_31_block = getattr(args, "use_31_block", True)
         # preprocess
         size = (self.model.image_size,) * 2
-        videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
+        videos = torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in videos])
         videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
 
         # forward
         with torch.amp.autocast("cuda", dtype=self.dtype):
-            out = self.model.visual(videos, use_31_block=True)
+            out = self.model.visual(videos, use_31_block=use_31_block)
         if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cpu()
@@ -466,51 +462,3 @@ class CLIPModel:
     def to_cpu(self):
         self.model = self.model.cpu()
 
-
-class WanVideoIPHandler:
-    def __init__(self, model_name, repo_or_path, require_grad=False, mode="eval", device="cuda", dtype=torch.float16):
-        # image_processor = CLIPImageProcessor.from_pretrained(
-        #     repo_or_path, subfolder='image_processor')
-        """720P-I2V-diffusers config is
-            "size": {
-                "shortest_edge": 224
-            }
-        and 480P-I2V-diffusers config is
-            "size": {
-                "height": 224,
-                "width": 224
-            }
-        but Wan2.1 official use no_crop resize by default
-        so I don't use CLIPImageProcessor
-        """
-        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
-        logger.info(f"Using image encoder {model_name} from {repo_or_path}")
-        image_encoder.requires_grad_(require_grad)
-        if mode == "eval":
-            image_encoder.eval()
-        else:
-            image_encoder.train()
-        self.dtype = dtype
-        self.device = device
-        self.image_encoder = image_encoder.to(device=device, dtype=dtype)
-        self.size = (224, 224)
-        mean = [0.48145466, 0.4578275, 0.40821073]
-        std = [0.26862954, 0.26130258, 0.27577711]
-        self.normalize = T.Normalize(mean=mean, std=std)
-        # self.image_processor = image_processor
-
-    def encode(
-        self,
-        img_tensor: Tensor,
-    ):
-        if img_tensor.ndim == 5:  # B C T H W
-            # img_tensor = img_tensor[:, :, 0]
-            img_tensor = rearrange(img_tensor, "B C 1 H W -> B C H W")
-        img_tensor = torch.clamp(img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
-        img_tensor = F.interpolate(img_tensor, size=self.size, mode="bicubic", align_corners=False)
-        img_tensor = self.normalize(img_tensor).to(self.dtype)
-        image_embeds = self.image_encoder(pixel_values=img_tensor, output_hidden_states=True)
-        return image_embeds.hidden_states[-1]
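The preprocessing change drops `u.transpose(0, 1)`, so each entry of `videos` is now expected to arrive already batched as `(N, C, H, W)` rather than `(C, T, H, W)`. A minimal shape sketch of the new path, reusing the CLIP normalization constants that appear in the removed `WanVideoIPHandler` above; the snippet is illustrative, not the module's API:

```python
# Sketch of the updated preprocessing in CLIPModel.visual(), assuming each list entry
# is already a (N, C, H, W) tensor in [-1, 1]. The 224 size and mean/std values are the
# standard CLIP constants shown in the removed handler code.
import torch
import torch.nn.functional as F
import torchvision.transforms as T

videos = [torch.rand(1, 3, 480, 832) * 2 - 1]   # one reference frame in [-1, 1]
size = (224, 224)                               # (self.model.image_size,) * 2

# No transpose any more: interpolate operates directly on the 4-D tensors.
x = torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in videos])

# Map [-1, 1] -> [0, 1], then apply the final Normalize transform (CLIP mean/std).
normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                        std=[0.26862954, 0.26130258, 0.27577711])
x = normalize(x.mul_(0.5).add_(0.5))
print(x.shape)  # torch.Size([1, 3, 224, 224])
```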
@@ -10,29 +10,18 @@ from dataclasses import dataclass
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.wan.wan_runner import WanRunner
-from lightx2v.models.runners.default_runner import DefaultRunner
-from lightx2v.models.schedulers.wan.scheduler import WanScheduler
-from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
-from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
-from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel, WanVideoIPHandler
 from lightx2v.models.networks.wan.audio_model import WanAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
-from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
-from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
-from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
-from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix, ConsistencyModelScheduler
 from loguru import logger
-import torch.distributed as dist
 from einops import rearrange
 import torchaudio as ta
 from transformers import AutoFeatureExtractor
-from torchvision.datasets.folder import IMG_EXTENSIONS
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import resize
@@ -618,12 +607,6 @@ class WanAudioRunner(WanRunner):
         return base_model
 
-    def load_image_encoder(self):
-        """Load image encoder"""
-        clip_model_dir = self.config["model_path"] + "/image_encoder"
-        image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
-        return image_encoder
-
     def run_image_encoder(self, config, vae_model):
         """Run image encoder"""
@@ -638,7 +621,7 @@ class WanAudioRunner(WanRunner):
             cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
             config.tgt_h = tgt_h
             config.tgt_w = tgt_w
-            clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
+            clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
             cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
             lat_h, lat_w = tgt_h // 8, tgt_w // 8
@@ -662,7 +645,7 @@ class WanAudioRunner(WanRunner):
             # Resize image to target size
             cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
-            clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
+            clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
             # Prepare for VAE encoding
             cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
...
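With `WanVideoIPHandler` and the audio runner's `load_image_encoder` removed, `run_image_encoder` now calls `CLIPModel.visual()` on the image encoder provided by the parent `WanRunner`, passing one `(1, C, H, W)` frame wrapped in a list. A sketch of that call site with a stub encoder so the shapes can be checked without loading weights; the stub and its output dimensions are assumptions, not the repo's encoder:

```python
# Sketch of the updated call site in WanAudioRunner.run_image_encoder(). FakeClip is a
# stand-in for the CLIPModel built by WanRunner; its (batch, tokens, dim) output shape
# is illustrative only.
import torch
from types import SimpleNamespace

class FakeClip:
    def visual(self, videos, args):
        # The real visual() resizes to image_size and runs the vision tower; here we
        # only reproduce the leading batch dimension of the result.
        batch = sum(v.shape[0] for v in videos)
        return torch.zeros(batch, 257, 1280)

config = SimpleNamespace(cpu_offload=False, use_31_block=False)
cond_frms = torch.zeros(1, 3, 480, 832)  # "1 C H W" reference frame, as in the diff

clip_encoder_out = FakeClip().visual([cond_frms], config).squeeze(0).to(torch.bfloat16)
print(clip_encoder_out.shape, clip_encoder_out.dtype)  # torch.Size([257, 1280]) torch.bfloat16
```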
@@ -197,7 +197,7 @@ class WanRunner(DefaultRunner):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
...
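The indexing change swaps a fake temporal axis for a leading batch axis: `img[None, :, :, :]` yields `(1, C, H, W)`, which is what `F.interpolate` expects now that `visual()` no longer transposes its inputs, whereas the old `img[:, None, :, :]` produced `(C, 1, H, W)` and relied on the removed `transpose(0, 1)`. A quick shape check:

```python
# Shape check for the indexing change in WanRunner.run_image_encoder().
import torch

img = torch.rand(3, 480, 832)   # (C, H, W) after TF.to_tensor(...)

old = img[:, None, :, :]        # (3, 1, 480, 832): needed the old transpose(0, 1) inside visual()
new = img[None, :, :, :]        # (1, 3, 480, 832): batched layout F.interpolate consumes directly

print(old.shape, new.shape)
```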