Commit 5ec2b691 authored by sandy, committed by GitHub

Merge pull request #135 from ModelTC/audio_r2v

Audio r2v v2
parents 6d07a72e e08c4f90
 {
-    "infer_steps": 5,
+    "infer_steps": 4,
     "target_fps": 16,
-    "video_duration": 12,
+    "video_duration": 16,
     "audio_sr": 16000,
     "target_video_length": 81,
     "target_height": 480,
@@ -13,5 +13,6 @@
     "sample_guide_scale": 1,
     "sample_shift": 5,
     "enable_cfg": false,
-    "cpu_offload": false
+    "cpu_offload": false,
+    "use_tiling_vae": true
 }
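The new values trade a little quality for speed and memory: one fewer inference step, a longer 16 s target duration, and tiled VAE decoding. For orientation, a small sketch (not part of the commit; the config file name is a placeholder) that loads the JSON and derives the per-segment length implied by these settings:

import json

with open("wan_audio_r2v.json") as f:          # placeholder path for this config file
    cfg = json.load(f)

# 81 frames at 16 fps is a ~5 s segment, so a 16 s clip is stitched from
# several segments with a short overlap (see the runner changes below).
seconds_per_segment = cfg["target_video_length"] / cfg["target_fps"]
print(cfg["infer_steps"], seconds_per_segment, cfg.get("use_tiling_vae", False))
# -> 4 5.0625 True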
@@ -24,13 +24,15 @@ class WanAudioPreInfer(WanPreInfer):
         self.text_len = config["text_len"]

     def infer(self, weights, inputs, positive):
-        ltnt_channel = self.scheduler.latents.size(0)
+        ltnt_frames = self.scheduler.latents.size(1)
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
         prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
         hidden_states = self.scheduler.latents.unsqueeze(0)
-        hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
+        # hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
+        # print(f"{prev_mask.shape}, {hidden_states.shape}, {prev_latents.shape}, {prev_latents[:, :, :ltnt_frames].shape}")
+        hidden_states = torch.cat([hidden_states, prev_mask, prev_latents[:, :, :ltnt_frames]], dim=1)
         hidden_states = hidden_states.squeeze(0)
         x = [hidden_states]
@@ -44,6 +46,7 @@ class WanAudioPreInfer(WanPreInfer):
             "timestep": t,
         }
         audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
+        ## audio_dit_blocks = None  ## Debug: Drop Audio
         if positive:
             context = inputs["text_encoder_output"]["context"]
......
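The infer change swaps the conditioning layout: instead of slicing the noisy latents by channel, the new code keeps them whole and appends the rearranged previous-frame mask plus the previous latents, truncated to the scheduler's frame count. A shape-level sketch (the 16-channel, 21-frame, 60x104 latent sizes are illustrative assumptions for 480x832, 81-frame output):

import torch

C, T, H, W = 16, 21, 60, 104                     # assumed Wan latent shape
latents = torch.randn(C, T, H, W)                # stands in for self.scheduler.latents
prev_latents = torch.randn(1, C, T, H, W)        # encoded previous segment
prev_mask = torch.randn(1, 4, T, H, W)           # wan_mask_rearrange(...) after unsqueeze(0)

hidden = latents.unsqueeze(0)
# the slice guards against prev_latents carrying more frames than the scheduler
hidden = torch.cat([hidden, prev_mask, prev_latents[:, :, :T]], dim=1)
print(hidden.shape)                              # torch.Size([1, 36, 21, 60, 104]): 16 + 4 + 16 channels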
@@ -18,6 +18,9 @@ from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
+from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
+from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix
 from loguru import logger
 import torch.distributed as dist
 from einops import rearrange
@@ -33,6 +36,45 @@ import warnings
 from typing import Optional, Tuple, Union
def add_mask_to_frames(
frames: np.ndarray,
mask_rate: float = 0.1,
rnd_state: np.random.RandomState = None,
) -> np.ndarray:
if mask_rate is None:
return frames
if rnd_state is None:
rnd_state = np.random.RandomState()
h, w = frames.shape[-2:]
mask = rnd_state.rand(h, w) > mask_rate
frames = frames * mask
return frames
def add_noise_to_frames(
frames: np.ndarray,
noise_mean: float = -3.0,
noise_std: float = 0.5,
rnd_state: np.random.RandomState = None,
) -> np.ndarray:
if noise_mean is None or noise_std is None:
return frames
if rnd_state is None:
rnd_state = np.random.RandomState()
shape = frames.shape
bs = 1 if len(shape) == 4 else shape[0]
sigma = rnd_state.normal(loc=noise_mean, scale=noise_std, size=(bs,))
sigma = np.exp(sigma)
sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
noise = rnd_state.randn(*shape) * sigma
frames = frames + noise
return frames
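Both helpers operate on plain NumPy arrays, so they can be exercised in isolation. A minimal sketch (not part of the commit; the clip shape and seed are illustrative) that corrupts a dummy clip the way the runner later corrupts the overlap frames:

import numpy as np

rng = np.random.RandomState(0)
frames = rng.uniform(-1.0, 1.0, size=(1, 3, 5, 480, 832)).astype(np.float32)  # fake 5-frame clip

noisy = add_noise_to_frames(frames, noise_mean=-3.0, noise_std=0.5, rnd_state=rng)
masked = add_mask_to_frames(noisy, mask_rate=0.1, rnd_state=rng)

print(float(np.abs(noisy - frames).mean()))   # small: sigma = exp(N(-3, 0.5)) keeps the noise weak
print(float((masked == 0).mean()))            # ~0.1: one H x W mask is shared across frames and channels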
 def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
     tgt_ar = tgt_h / tgt_w
     ori_ar = ori_h / ori_w
@@ -75,7 +117,12 @@ def adaptive_resize(img):
     aspect_ratios = np.array(np.array(list(bucket_config.keys())))
     closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
     closet_ratio = aspect_ratios[closet_aspect_idx]
-    target_h, target_w = 480, 832
+    if ori_ratio < 1.0:
+        target_h, target_w = 480, 832
+    elif ori_ratio == 1.0:
+        target_h, target_w = 480, 480
+    else:
+        target_h, target_w = 832, 480
     for resolution in bucket_config[closet_ratio][0]:
         if ori_height * ori_weight >= resolution[0] * resolution[1]:
             target_h, target_w = resolution
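The new branch only changes the fallback bucket before the resolution loop refines it: it is keyed on orientation instead of always defaulting to 480x832. A standalone restatement (bucket_config and the refinement loop are omitted):

def default_bucket(ori_height: int, ori_width: int):
    ori_ratio = ori_height / ori_width
    if ori_ratio < 1.0:          # wider than tall -> landscape bucket
        return 480, 832
    elif ori_ratio == 1.0:       # square
        return 480, 480
    else:                        # taller than wide -> portrait bucket
        return 832, 480

print(default_bucket(720, 1280), default_bucket(1024, 1024), default_bucket(1280, 720))
# (480, 832) (480, 480) (832, 480)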
@@ -253,6 +300,10 @@ class WanAudioRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)

+    def init_scheduler(self):
+        scheduler = EulerSchedulerTimestepFix(self.config)
+        self.model.set_scheduler(scheduler)
+
     def load_audio_models(self):
         ## Audio feature extractor
         self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
@@ -372,6 +423,18 @@ class WanAudioRunner(WanRunner):
     audio_frame_rate = audio_sr / fps
     return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
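get_audio_range maps an inclusive range of video frames to audio sample indices; restated below for a quick numeric check with the config's fps and sample rate (the signature is inferred from the call sites later in the file):

def get_audio_range(start_frame, end_frame, fps=16, audio_sr=16000):
    audio_frame_rate = audio_sr / fps      # 1000 audio samples per video frame here
    return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)

print(get_audio_range(0, 80))              # (0, 81000): the first 81-frame segment covers 81000 samples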
def wan_mask_rearrange(mask: torch.Tensor):
# mask: 1, T, H, W, where 1 means the input mask is one-channel
if mask.ndim == 3:
mask = mask[None]
assert mask.ndim == 4
_, t, h, w = mask.shape
assert t == ((t - 1) // 4 * 4 + 1)
mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
mask = mask.view(mask.shape[1] // 4, 4, h, w)
return mask.transpose(0, 1) # 4, T // 4, H, W
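wan_mask_rearrange packs a one-channel pixel-frame mask into the 4-channel, temporally compressed layout the DiT expects, mirroring Wan's 4x temporal VAE stride. A quick shape trace (the latent-sized 60x104 spatial dims are an illustrative assumption):

import torch

mask = torch.ones(1, 81, 60, 104)     # 1 channel, 81 pixel frames, latent-sized H x W
out = wan_mask_rearrange(mask)
# repeat frame 0 four times -> (1, 84, 60, 104) -> view(21, 4, 60, 104) -> transpose(0, 1)
print(out.shape)                      # torch.Size([4, 21, 60, 104])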
         self.inputs["audio_adapter_pipe"] = self.load_audio_models()

         # process audio
@@ -427,7 +490,14 @@
             elif res_frame_num > 5 and idx == interval_num - 1:  # the last segment may have fewer than 81 frames
                 prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
+                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
+                last_frames = last_frames.cpu().detach().numpy()
+                last_frames = add_noise_to_frames(last_frames)
+                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
+                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
+                prev_frames[:, :, :prev_frame_length] = last_frames
                 prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                 prev_len = prev_token_length
                 audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
@@ -438,7 +508,14 @@
             else:  # middle segments are a full 81 frames and carry prev_latents
                 prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
+                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
+                last_frames = last_frames.cpu().detach().numpy()
+                last_frames = add_noise_to_frames(last_frames)
+                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
+                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
+                prev_frames[:, :, :prev_frame_length] = last_frames
                 prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                 prev_len = prev_token_length
                 audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
@@ -452,11 +529,11 @@
         if prev_latents is not None:
             ltnt_channel, nframe, height, width = self.model.scheduler.latents.shape
-            bs = 1
-            prev_mask = torch.zeros((bs, 1, nframe, height, width), device=device, dtype=dtype)
-            if prev_len > 0:
-                prev_mask[:, :, :prev_len] = 1.0
+            # bs = 1
+            frames_n = (nframe - 1) * 4 + 1
+            prev_mask = torch.zeros((1, frames_n, height, width), device=device, dtype=dtype)
+            prev_mask[:, prev_len:] = 0
+            prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
             previmg_encoder_output = {
                 "prev_latents": prev_latents,
                 "prev_mask": prev_mask,
@@ -483,13 +560,13 @@
             start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
             if res_frame_num > 5 and idx == interval_num - 1:
-                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
+                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num].cpu())
                 cut_audio_list.append(audio_array[start_audio_frame:useful_length])
             elif expected_frames < max_num_frames and useful_length != -1:
-                gen_video_list.append(gen_video[:, :, start_frame:expected_frames])
+                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
                 cut_audio_list.append(audio_array[start_audio_frame:useful_length])
             else:
-                gen_video_list.append(gen_video[:, :, start_frame:])
+                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
                 cut_audio_list.append(audio_array[start_audio_frame:])

         gen_lvideo = torch.cat(gen_video_list, dim=2).float()
......
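Appending each decoded segment with .cpu() is a memory fix rather than a functional change: a single decoded segment is large, and a 16 s clip needs several of them. Rough arithmetic (float32, 480x832 output assumed):

frames, channels, h, w, bytes_per_elem = 81, 3, 480, 832, 4
segment_mib = frames * channels * h * w * bytes_per_elem / 2**20
print(round(segment_mib))   # ~370 MiB per decoded segment, so keeping them all on GPU adds up quickly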
import os
import gc
import math
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from .utils import unsqueeze_to_ndim
from diffusers import (
FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase, # pyright: ignore
)
def get_timesteps(num_steps, max_steps: int = 1000):
return np.linspace(max_steps, 0, num_steps + 1, dtype=np.float32)
def timestep_shift(timesteps, shift: float = 1.0):
return shift * timesteps / (1 + (shift - 1) * timesteps)
class FlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteSchedulerBase):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.init_noise_sigma = 1.0
def add_noise(self, x0: Tensor, noise: Tensor, timesteps: Tensor):
dtype = x0.dtype
device = x0.device
sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
sigma = unsqueeze_to_ndim(sigma, x0.ndim)
xt = x0.float() * (1 - sigma) + noise.float() * sigma
return xt.to(dtype)
def get_velocity(self, x0: Tensor, noise: Tensor, timesteps: Tensor | None = None):
return noise - x0
def velocity_loss_to_x_loss(self, v_loss: Tensor, timesteps: Tensor):
device = v_loss.device
sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
return v_loss.float() * (sigma**2)
class EulerSchedulerTimestepFix(FlowMatchEulerDiscreteScheduler):
def __init__(self, config):
self.config = config
self.step_index = 0
self.latents = None
self.caching_records = [True] * config.infer_steps
self.flag_df = False
self.transformer_infer = None
self.device = torch.device("cuda")
self.infer_steps = self.config.infer_steps
self.target_video_length = self.config.target_video_length
self.sample_shift = self.config.sample_shift
self.num_train_timesteps = 1000
self.noise_pred = None
def step_pre(self, step_index):
self.step_index = step_index
if GET_DTYPE() == "BF16":
self.latents = self.latents.to(dtype=torch.bfloat16)
def set_shift(self, shift: float = 1.0):
self.sigmas = self.timesteps_ori / self.num_train_timesteps
self.sigmas = timestep_shift(self.sigmas, shift=shift)
self.timesteps = self.sigmas * self.num_train_timesteps
def set_timesteps(
self,
infer_steps: Union[int, None] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
timesteps = get_timesteps(num_steps=infer_steps, max_steps=self.num_train_timesteps)
self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device or self.device)
self.timesteps_ori = self.timesteps.clone()
self.set_shift(self.sample_shift)
self._step_index = None
self._begin_index = None
def prepare(self, image_encoder_output=None):
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(self.config.seed)
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
if os.path.isfile(self.config.image_path):
self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
else:
self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.set_timesteps(infer_steps=self.infer_steps, device=self.device, shift=self.sample_shift)
def prepare_latents(self, target_shape, dtype=torch.float32):
self.latents = (
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
)
* self.init_noise_sigma
)
def step_post(self):
model_output = self.noise_pred.to(torch.float32)
timestep = self.timesteps[self.step_index]
sample = self.latents.to(torch.float32)
if self.step_index is None:
self._init_step_index(timestep)
sample = sample.to(torch.float32) # pyright: ignore
sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim)
sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim)
# x0 = sample - model_output * sigma
x_t_next = sample + (sigma_next - sigma) * model_output
self._step_index += 1
return x_t_next
def reset(self):
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.noise_pred = None
self.this_order = None
self.lower_order_nums = 0
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
gc.collect()
torch.cuda.empty_cache()
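With infer_steps=4 and sample_shift=5 from the config, get_timesteps plus set_shift produce a front-loaded schedule; the warp is restated below for a standalone check:

import numpy as np

def timestep_shift(timesteps, shift=1.0):
    return shift * timesteps / (1 + (shift - 1) * timesteps)

t = np.linspace(1000, 0, 5, dtype=np.float32)     # get_timesteps(4, max_steps=1000)
sigmas = t / 1000.0
print(timestep_shift(sigmas, shift=5.0) * 1000)   # approx [1000, 937.5, 833.3, 625, 0]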
import os
import gc
import math
import warnings  # used by unsqueeze_to_ndim below
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from lightx2v.models.schedulers.scheduler import BaseScheduler
from loguru import logger
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from diffusers import (
FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase, # pyright: ignore
)
def unsqueeze_to_ndim(in_tensor: Tensor, tgt_n_dim: int):
if in_tensor.ndim > tgt_n_dim:
warnings.warn(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the original tensor will be returned")
return in_tensor
if in_tensor.ndim < tgt_n_dim:
in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
return in_tensor
class EulerSchedulerTimestepFix(BaseScheduler):
def __init__(self, config, **kwargs):
# super().__init__(**kwargs)
self.init_noise_sigma = 1.0
self.config = config
self.latents = None
self.caching_records = [True] * config.infer_steps
self.flag_df = False
self.transformer_infer = None
self.solver_order = 2
self.device = torch.device("cuda")
self.infer_steps = self.config.infer_steps
self.target_video_length = self.config.target_video_length
self.sample_shift = self.config.sample_shift
self.shift = 1
self.num_train_timesteps = 1000
self.step_index = None
self.noise_pred = None
self._step_index = None
self._begin_index = None
def step_pre(self, step_index):
self.step_index = step_index
if GET_DTYPE() == "BF16":
self.latents = self.latents.to(dtype=torch.bfloat16)
def set_timesteps(
self,
infer_steps: Union[int, None] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]
if shift is None:
shift = self.shift
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
sigma_last = 0
timesteps = sigmas * self.num_train_timesteps
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
assert len(self.timesteps) == self.infer_steps
self.model_outputs = [
None,
] * self.solver_order
self.lower_order_nums = 0
self.last_sample = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu")
def prepare(self, image_encoder_output=None):
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(self.config.seed)
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
if self.config.task in ["t2v"]:
self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
elif self.config.task in ["i2v"]:
self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
self.sigmas = sigmas
self.timesteps = sigmas * self.num_train_timesteps
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.sigmas = self.sigmas.to("cpu")
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
def prepare_latents(self, target_shape, dtype=torch.float32):
self.latents = (
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
)
* self.init_noise_sigma
)
def step_post(self):
model_output = self.noise_pred.to(torch.float32)
sample = self.latents.to(torch.float32)
sample = sample.to(torch.float32)
sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
# x0 = sample - model_output * sigma
x_t_next = sample + (sigma_next - sigma) * model_output
self.latents = x_t_next
def reset(self):
self.model_outputs = [None]
self.timestep_list = [None]
self.last_sample = None
self.noise_pred = None
self.this_order = None
self.lower_order_nums = 0
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
gc.collect()
torch.cuda.empty_cache()
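step_post is a plain Euler step on the flow-matching ODE: x_next = x + (sigma_next - sigma) * v, where the model's prediction v targets noise - x0. A scalar sanity check using the add_noise convention from the scheduler above:

import torch

x0, noise = torch.tensor(2.0), torch.tensor(-1.0)
sigma, sigma_next = torch.tensor(0.25), torch.tensor(0.0)   # final step of the schedule

xt = (1 - sigma) * x0 + sigma * noise     # add_noise: interpolate clean sample and noise
v = noise - x0                            # exact velocity target
x_next = xt + (sigma_next - sigma) * v    # step_post's update
print(x_next)                             # tensor(2.): lands exactly on x0 when v is exact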
 #!/bin/bash
 # set path and first
 lightx2v_path=
 model_path=
 lora_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
@@ -42,5 +44,4 @@ python -m lightx2v.infer \
     --negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
     --image_path ${lightx2v_path}/assets/inputs/audio/15.png \
     --audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
-    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4 \
-    --lora_path ${lora_path}
+    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4
@@ -5,12 +5,13 @@
 ### ViGen-DiT Project Url: https://github.com/yl-1993/ViGen-DiT
 ###
 import torch
-from safetensors.torch import save_file
 import sys
 import os
+from safetensors.torch import save_file
+from safetensors.torch import load_file

 if len(sys.argv) != 3:
-    print("Usage: python convert_lora.py <input_file.pt> <output_file.safetensors>")
+    print("Usage: python convert_lora.py <input_file> <output_file.safetensors>")
     sys.exit(1)

 ckpt_path = sys.argv[1]
@@ -20,7 +21,10 @@ if not os.path.exists(ckpt_path):
     print(f"❌ Input file does not exist: {ckpt_path}")
     sys.exit(1)

-state_dict = torch.load(ckpt_path, map_location="cpu")
+if ckpt_path.endswith(".safetensors"):
+    state_dict = load_file(ckpt_path)
+else:
+    state_dict = torch.load(ckpt_path, map_location="cpu")

 if "state_dict" in state_dict:
     state_dict = state_dict["state_dict"]
......