Commit d02b97a7 authored by wangshankun

Replace CLIP model

parent e58dd9fe
@@ -11,6 +11,9 @@ import torchvision.transforms as T
from lightx2v.attentions import attention
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
+from einops import rearrange
+from torch import Tensor
+from transformers import CLIPVisionModel
__all__ = [
@@ -428,3 +431,67 @@ class CLIPModel:
    def to_cpu(self):
        self.model = self.model.cpu()
+class WanVideoIPHandler:
+    def __init__(self,
+                 model_name,
+                 repo_or_path,
+                 require_grad=False,
+                 mode='eval',
+                 device='cuda',
+                 dtype=torch.float16):
+        # image_processor = CLIPImageProcessor.from_pretrained(
+        #     repo_or_path, subfolder='image_processor')
+        '''The 720P-I2V-diffusers config is
+            "size": {
+                "shortest_edge": 224
+            }
+        while the 480P-I2V-diffusers config is
+            "size": {
+                "height": 224,
+                "width": 224
+            }
+        but the official Wan2.1 code uses a no-crop resize by default,
+        so CLIPImageProcessor is not used here.
+        '''
+        image_encoder = CLIPVisionModel.from_pretrained(
+            repo_or_path, subfolder='image_encoder', torch_dtype=dtype)
+        logger.info(
+            f'Using image encoder {model_name} from {repo_or_path}')
+        image_encoder.requires_grad_(require_grad)
+        if mode == 'eval':
+            image_encoder.eval()
+        else:
+            image_encoder.train()
+        self.dtype = dtype
+        self.device = device
+        self.image_encoder = image_encoder.to(device=device, dtype=dtype)
+        self.size = (224, 224)
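+        # Standard OpenAI CLIP normalization statistics (RGB mean/std)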
+        mean = [0.48145466, 0.4578275, 0.40821073]
+        std = [0.26862954, 0.26130258, 0.27577711]
+        self.normalize = T.Normalize(mean=mean, std=std)
+        # self.image_processor = image_processor
+    def encode(
+        self,
+        img_tensor: Tensor,
+    ):
+        if img_tensor.ndim == 5:  # B C T H W with a single frame (T == 1)
+            # img_tensor = img_tensor[:, :, 0]
+            img_tensor = rearrange(img_tensor, "B C 1 H W -> B C H W")
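+        # Inputs are expected in [-1, 1]; rescale to [0, 1] before CLIP normalization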
+        img_tensor = torch.clamp(
+            img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
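+        # No-crop bicubic resize straight to 224x224 (aspect ratio is not
+        # preserved), matching the Wan2.1 default noted in the docstring above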
+        img_tensor = F.interpolate(
+            img_tensor, size=self.size, mode='bicubic', align_corners=False)
+        img_tensor = self.normalize(img_tensor).to(self.dtype)
+        logger.info(
+            f'Image tensor shape after processing: {img_tensor.shape}')
+        image_embeds = self.image_encoder(
+            pixel_values=img_tensor, output_hidden_states=True)
+        logger.info(
+            f'Image embeds shape: {image_embeds.hidden_states[-1].shape}')
+        return image_embeds.hidden_states[-1]
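+
+# A minimal usage sketch (hypothetical path; assumes a Diffusers-layout
+# Wan2.1-I2V checkpoint that contains an `image_encoder` subfolder):
+#     handler = WanVideoIPHandler("CLIPModel", repo_or_path="/path/to/Wan2.1-I2V-14B-720P-Diffusers")
+#     embeds = handler.encode(img)  # img: (B, C, H, W) or (B, C, 1, H, W), values in [-1, 1]
+#     # embeds are the encoder's last hidden_states (CLS + patch tokens) per image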
\ No newline at end of file
@@ -376,6 +376,7 @@ class AudioAdapterPipe:
        self.device = device
        self.generator = generator
        self.audio_encoder_dtype = torch.float16
+        ## Audio encoder
        self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
        self.audio_encoder.eval()
@@ -11,7 +11,7 @@ from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
-from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
+from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel, WanVideoIPHandler
from lightx2v.models.networks.wan.audio_model import WanAudioModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
@@ -244,7 +244,8 @@ class WanAudioRunner(WanRunner):
        super().__init__(config)

    def load_audio_models(self):
-        self.audio_encoder = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
+        ## Audio feature extractor
+        self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
        audio_adaper = AudioAdapter.from_transformer(
            self.model,
            audio_feature_dim=1024,
@@ -265,6 +266,18 @@ class WanAudioRunner(WanRunner):
        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
        return base_model
+
+    def load_image_encoder(self):
+        image_encoder = WanVideoIPHandler(
+            "CLIPModel",
+            repo_or_path="/mnt/aigc/zoemodels/Wan21/Wan2.1-I2V-14B-720P-Diffusers",
+            require_grad=False,
+            mode='eval',
+            device=self.init_device,
+            dtype=torch.float16)
+        return image_encoder

    def run_image_encoder(self, config, vae_model):
        ref_img = Image.open(config.image_path)
        ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
@@ -276,8 +289,7 @@ class WanAudioRunner(WanRunner):
        cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
        config.tgt_h = tgt_h
        config.tgt_w = tgt_w
-        clip_encoder_out = self.image_encoder.visual([cond_frms.squeeze(0)[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
        cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
        lat_h, lat_w = tgt_h // 8, tgt_w // 8
@@ -393,7 +405,7 @@ class WanAudioRunner(WanRunner):
            if expected_frames < max_num_frames:
                useful_length = audio_array.shape[0]
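                # Zero-pad the audio so the segment has a fixed length of max_num_audio_length samples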
                audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
-                audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
+                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
            elif res_frame_num > 5 and idx == interval_num - 1:  # the last segment may have fewer than 81 frames
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
@@ -404,7 +416,7 @@ class WanAudioRunner(WanRunner):
                audio_array = audio_array_ori[audio_start:audio_end]
                useful_length = audio_array.shape[0]
                audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
-                audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
+                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
            else:  # middle segments: a full 81 frames, with pre_latens carried over
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
@@ -413,7 +425,7 @@ class WanAudioRunner(WanRunner):
                prev_len = prev_token_length
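                # Map the frame window to an audio sample range via fps and audio_sr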
                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
                audio_array = audio_array_ori[audio_start:audio_end]
-                audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
+                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
            self.inputs["audio_encoder_output"] = audio_input_feat.to(device)