"torchvision/vscode:/vscode.git/clone" did not exist on "6518372ea80a14fc2ca07758df83528dc8973a71"
Unverified Commit b20ec092 authored by sandy, committed by GitHub

[Feat] Support video super resolution (#385)

parent 62789aa4
{
"infer_steps": 2,
"target_fps": 25,
"video_duration": 1,
"audio_sr": 16000,
"target_video_length": 25,
"resize_mode": "fixed_shape",
"fixed_shape": [
192,
320
],
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": 1.0,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"use_31_block": false,
"video_super_resolution": {
"scale": 2.0,
"seed": 0,
"model_path": "/base_code/FlashVSR/examples/WanVSR/FlashVSR"
}
}
@@ -72,6 +72,15 @@ class DefaultRunner(BaseRunner):
else:
raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")
def load_vsr_model(self):
if "video_super_resolution" in self.config:
from lightx2v.models.runners.vsr.vsr_wrapper import VSRWrapper
logger.info("Loading VSR model...")
return VSRWrapper(self.config["video_super_resolution"]["model_path"])
else:
return None
@ProfilingContext4DebugL2("Load models")
def load_model(self):
self.model = self.load_transformer()
@@ -79,6 +88,7 @@ class DefaultRunner(BaseRunner):
self.image_encoder = self.load_image_encoder()
self.vae_encoder, self.vae_decoder = self.load_vae()
self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None
self.vsr_model = self.load_vsr_model() if "video_super_resolution" in self.config else None
def check_sub_servers(self, task_type):
urls = self.config.get("sub_servers", {}).get(task_type, [])
......
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video (Decoder-only, pruned)
- Encoder removed
- Transplant/widening helpers removed
- Deepening (IdentityConv2d+ReLU) is now built into the decoder structure itself
"""
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from einops import rearrange
from tqdm.auto import tqdm
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
# ----------------------------
# Utility / building blocks
# ----------------------------
class IdentityConv2d(nn.Conv2d):
"""Same-shape Conv2d initialized to identity (Dirac)."""
def __init__(self, C, kernel_size=3, bias=False):
pad = kernel_size // 2
super().__init__(C, C, kernel_size, padding=pad, bias=bias)
with torch.no_grad():
init.dirac_(self.weight)
if self.bias is not None:
self.bias.zero_()
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = nn.ReLU(inplace=True)
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
class PixelShuffle3d(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff = ff
self.hh = hh
self.ww = ww
def forward(self, x):
# x: (B, C, F, H, W)
B, C, F, H, W = x.shape
if F % self.ff != 0:
first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1)
x = torch.cat([first_frame, x], dim=2)
return rearrange(x, "b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w", ff=self.ff, hh=self.hh, ww=self.ww).transpose(1, 2)
# ----------------------------
# Generic NTCHW graph executor (kept; used by decoder)
# ----------------------------
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None):
"""
Apply a sequential model with memblocks to the given input.
Args:
- model: nn.Sequential of blocks to apply
- x: input data, of dimensions NTCHW
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
- show_progress_bar: if True, enables tqdm progressbar display
- mem: list of per-block memory states (length len(model)); used by the sequential path and carried across calls for streaming decode
Returns a tuple (NTCHW output tensor, updated mem).
"""
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
N, T, C, H, W = x.shape
if parallel:
x = x.reshape(N * T, C, H, W)
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
NT, C, H, W = x.shape
T = NT // N
_x = x.reshape(N, T, C, H, W)
mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
NT, C, H, W = x.shape
T = NT // N
x = x.view(N, T, C, H, W)
else:
out = []
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
progress_bar = tqdm(range(T), disable=not show_progress_bar)
while work_queue:
xt, i = work_queue.pop(0)
if i == 0:
progress_bar.update(1)
if i == len(model):
out.append(xt)
else:
b = model[i]
if isinstance(b, MemBlock):
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt
else:
xt_new = b(xt, mem[i])
mem[i].copy_(xt)
work_queue.insert(0, TWorkItem(xt_new, i + 1))
elif isinstance(b, TPool):
if mem[i] is None:
mem[i] = []
mem[i].append(xt)
if len(mem[i]) > b.stride:
raise ValueError("TPool internal state invalid.")
elif len(mem[i]) == b.stride:
N_, C_, H_, W_ = xt.shape
xt = b(torch.cat(mem[i], 1).view(N_ * b.stride, C_, H_, W_))
mem[i] = []
work_queue.insert(0, TWorkItem(xt, i + 1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C_, H_, W_ = xt.shape
for xt_next in reversed(xt.view(N, b.stride * C_, H_, W_).chunk(b.stride, 1)):
work_queue.insert(0, TWorkItem(xt_next, i + 1))
else:
xt = b(xt)
work_queue.insert(0, TWorkItem(xt, i + 1))
progress_bar.close()
x = torch.stack(out, 1)
return x, mem
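# Illustrative sketch: how the parallel branch above builds the "previous timestep" input
# for each MemBlock. F.pad with (0, 0, 0, 0, 0, 0, 1, 0) prepends one zero frame along T and
# the [:, :T] slice drops the last frame, so timestep t sees timestep t-1 (zeros at t=0).
# `_shifted_memory` is a hypothetical helper used only for this demonstration.
def _shifted_memory(x_ntchw, n):
    """x_ntchw: (N*T, C, H, W) tensor; returns each timestep's previous-timestep features."""
    nt, c, h, w = x_ntchw.shape
    t = nt // n
    x = x_ntchw.reshape(n, t, c, h, w)
    prev = F.pad(x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :t]  # shift right along T
    return prev.reshape(nt, c, h, w)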
# ----------------------------
# Decoder-only TAEHV
# ----------------------------
class TAEHV(nn.Module):
image_channels = 3
def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), channels=[256, 128, 64, 64], latent_channels=16):
"""Initialize TAEHV (decoder-only) with built-in deepening after every ReLU.
Deepening config: how_many_each=1, k=3 (fixed).
"""
super().__init__()
self.latent_channels = latent_channels
n_f = channels
self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
# Build the decoder "skeleton"
base_decoder = nn.Sequential(
Clamp(),
conv(self.latent_channels, n_f[0]),
nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]),
MemBlock(n_f[0], n_f[0]),
MemBlock(n_f[0], n_f[0]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
TGrow(n_f[0], 1),
conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]),
MemBlock(n_f[1], n_f[1]),
MemBlock(n_f[1], n_f[1]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]),
MemBlock(n_f[2], n_f[2]),
MemBlock(n_f[2], n_f[2]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True),
conv(n_f[3], TAEHV.image_channels),
)
# Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU
self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3)
self.pixel_shuffle = PixelShuffle3d(4, 8, 8)
if checkpoint_path is not None:
missing_keys = self.load_state_dict(self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)), strict=False)
print("missing_keys", missing_keys)
# Initialize decoder mem state
self.mem = [None] * len(self.decoder)
@staticmethod
def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential:
"""Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU)."""
new_layers = []
for b in decoder:
new_layers.append(b)
if isinstance(b, nn.ReLU):
# Deduce channel count from preceding layer
C = None
if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d):
C = new_layers[-2].out_channels
elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock):
C = new_layers[-2].conv[-1].out_channels
if C is not None:
for _ in range(how_many_each):
new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False))
new_layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*new_layers)
def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed (decoder-only)."""
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
if isinstance(layer, TGrow):
key = f"decoder.{i}.conv.weight"
if key in sd and sd[key].shape[0] > new_sd[key].shape[0]:
sd[key] = sd[key][-new_sd[key].shape[0] :]
return sd
def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None):
"""Decode a sequence of frames from latents.
x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1].
"""
trim_flag = self.mem[-8] is None # keeps original relative check
if cond is not None:
x = torch.cat([self.pixel_shuffle(cond), x], dim=2)
x, self.mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem)
if trim_flag:
return x[:, self.frames_to_trim :]
return x
def forward(self, *args, **kwargs):
raise NotImplementedError("Decoder-only model: call decode_video(...) instead.")
def clean_mem(self):
self.mem = [None] * len(self.decoder)
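# Usage sketch (assumed shapes, not a definitive API): streaming decode with the per-block
# memory. Reset the memory once, then feed latent chunks in temporal order with parallel=False;
# only the first call trims the warm-up frames (frames_to_trim).
def _streaming_decode_example(taehv, latent_chunks):
    """latent_chunks: iterable of NTCHW latent tensors in temporal order."""
    taehv.clean_mem()
    outputs = [taehv.decode_video(chunk, parallel=False, show_progress_bar=False) for chunk in latent_chunks]
    return torch.cat(outputs, dim=1)  # concatenate along the time axis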
class DotDict(dict):
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
class TAEW2_1DiffusersWrapper(nn.Module):
def __init__(self, pretrained_path=None, channels=[256, 128, 64, 64]):
super().__init__()
self.dtype = torch.bfloat16
self.device = "cuda"
self.taehv = TAEHV(pretrained_path, channels=channels).to(self.dtype)
self.temperal_downsample = [True, True, False] # [sic] name kept to match the upstream Wan VAE attribute
self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))
def decode(self, latents, return_dict=None):
n, c, t, h, w = latents.shape
return (self.taehv.decode_video(latents.transpose(1, 2), parallel=False).transpose(1, 2).mul_(2).sub_(1),)
def stream_decode_with_cond(self, latents, tiled=False, cond=None):
n, c, t, h, w = latents.shape
return self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond).transpose(1, 2).mul_(2).sub_(1)
def clean_mem(self):
self.taehv.clean_mem()
# ----------------------------
# Simplified builder (no small, no transplant, no post-hoc deepening)
# ----------------------------
def build_tcdecoder(new_channels=[512, 256, 128, 128], device="cuda", dtype=torch.bfloat16, new_latent_channels=None):
"""
Build the "wider" decoder; the deepening (IdentityConv2d+ReLU) is already built into TAEHV itself.
- No "small" model is created and no weight transplanting is performed
- No base checkpoint is loaded here (the former base_ckpt_path argument is gone)
Returns: big (a single model)
"""
if new_latent_channels is not None:
big = TAEHV(checkpoint_path=None, channels=new_channels, latent_channels=new_latent_channels).to(device).to(dtype).train()
else:
big = TAEHV(checkpoint_path=None, channels=new_channels).to(device).to(dtype).train()
big.clean_mem()
return big
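# Usage sketch (checkpoint path is a placeholder; requires a CUDA device): build the widened
# decoder and load a TCDecoder checkpoint non-strictly, mirroring how init_pipeline consumes it
# in the VSR wrapper below.
def _build_tcdecoder_example(ckpt_path="TCDecoder.ckpt"):
    decoder = build_tcdecoder(new_channels=[512, 256, 128, 128], new_latent_channels=16 + 768)
    missing = decoder.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
    return decoder, missing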
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
CACHE_T = 2
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class CausalConv3d(nn.Conv3d):
"""
Causal 3D convolution.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
# print(cache_x.shape, x.shape)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
# print('cache!')
x = F.pad(x, padding, mode="replicate") # mode='replicate'
# print(x[0,0,:,0,0])
return super().forward(x)
class PixelShuffle3d(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff = ff
self.hh = hh
self.ww = ww
def forward(self, x):
# x: (B, C, F, H, W)
return rearrange(x, "b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w", ff=self.ff, hh=self.hh, ww=self.ww)
class Buffer_LQ4x_Proj(nn.Module):
def __init__(self, in_dim, out_dim, layer_num=30):
super().__init__()
self.ff = 1
self.hh = 16
self.ww = 16
self.hidden_dim1 = 2048
self.hidden_dim2 = 3072
self.layer_num = layer_num
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
self.conv1 = CausalConv3d(in_dim * self.ff * self.hh * self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
self.act1 = nn.SiLU()
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
self.act2 = nn.SiLU()
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
self.clip_idx = 0
def forward(self, video):
self.clear_cache()
# x: (B, C, F, H, W)
t = video.shape[2]
iter_ = 1 + (t - 1) // 4
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video = torch.cat([first_frame, video], dim=2)
# print(video.shape)
out_x = []
for i in range(iter_):
x = self.pixel_shuffle(video[:, :, i * 4 : (i + 1) * 4, :, :])
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv1"] = cache1_x
x = self.conv1(x, self.cache["conv1"])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv2"] = cache2_x
if i == 0:
continue
x = self.conv2(x, self.cache["conv2"])
x = self.norm2(x)
x = self.act2(x)
out_x.append(x)
out_x = torch.cat(out_x, dim=2)
# print(out_x.shape)
out_x = rearrange(out_x, "b c f h w -> b (f h w) c")
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
return outputs
def clear_cache(self):
self.cache = {}
self.cache["conv1"] = None
self.cache["conv2"] = None
self.clip_idx = 0
def stream_forward(self, video_clip):
if self.clip_idx == 0:
# self.clear_cache()
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video_clip = torch.cat([first_frame, video_clip], dim=2)
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv1"] = cache1_x
x = self.conv1(x, self.cache["conv1"])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv2"] = cache2_x
self.clip_idx += 1
return None
else:
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv1"] = cache1_x
x = self.conv1(x, self.cache["conv1"])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv2"] = cache2_x
x = self.conv2(x, self.cache["conv2"])
x = self.norm2(x)
x = self.act2(x)
out_x = rearrange(x, "b c f h w -> b (f h w) c")
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
self.clip_idx += 1
return outputs
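# Usage sketch (shapes assumed from forward(), frame count of the form 4k+1): streaming
# projection over 4-frame clips. clear_cache() resets the causal-conv caches; the first call
# only primes the caches with frame 0 (stream_forward prepends 3 copies itself) and returns
# None, every later 4-frame clip returns one embedding tensor per linear layer.
def _stream_proj_example(proj, video):
    """video: (B, 3, F, H, W) pixel tensor; yields the per-clip outputs after the warm-up clip."""
    proj.clear_cache()
    t = video.shape[2]
    proj.stream_forward(video[:, :, :1])  # warm-up clip -> returns None
    for start in range(1, t, 4):
        yield proj.stream_forward(video[:, :, start : start + 4])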
import os
from typing import Optional
import torch
from torch.nn import functional as F
from lightx2v.utils.profiler import *
try:
from diffsynth import FlashVSRTinyPipeline, ModelManager
except ImportError:
ModelManager = None
FlashVSRTinyPipeline = None
from .utils.TCDecoder import build_tcdecoder
from .utils.utils import Buffer_LQ4x_Proj
def largest_8n1_leq(n): # 8n+1
return 0 if n < 1 else ((n - 1) // 8) * 8 + 1
def compute_scaled_and_target_dims(w0: int, h0: int, scale: float = 4.0, multiple: int = 128):
if w0 <= 0 or h0 <= 0:
raise ValueError("Invalid original size")
if scale <= 0:
raise ValueError("scale must be > 0")
sW = int(round(w0 * scale))
sH = int(round(h0 * scale))
tW = (sW // multiple) * multiple
tH = (sH // multiple) * multiple
if tW == 0 or tH == 0:
raise ValueError(f"Scaled size too small ({sW}x{sH}) for multiple={multiple}. Increase scale (got {scale}).")
return sW, sH, tW, tH
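# Worked example (assuming fixed_shape [192, 320] from the config above is height x width):
#   w0 = 320, h0 = 192, scale = 2.0, multiple = 128
#   sW = round(320 * 2.0) = 640, sH = round(192 * 2.0) = 384
#   tW = (640 // 128) * 128 = 640, tH = (384 // 128) * 128 = 384
# so the super-resolved output is 640x384 and the center crop below removes nothing in this case.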
def prepare_input_tensor(input_tensor, scale: float = 2.0, dtype=torch.bfloat16, device="cuda"):
"""
Video preprocessing: [T,H,W,3] -> [1,C,F,H,W]
1. Interpolation + center crop are done on the GPU
2. Frame count is automatically padded (repeating the last frame) to 8n+1 in total, i.e. 8n-3 effective frames
"""
input_tensor = input_tensor.to(device=device, dtype=torch.float32) # [T,H,W,3]
total, h0, w0, _ = input_tensor.shape
# Compute the scaled and target resolutions
sW, sH, tW, tH = compute_scaled_and_target_dims(w0, h0, scale=scale, multiple=128)
print(f"Scaled (x{scale:.2f}): {sW}x{sH} -> Target: {tW}x{tH}")
# Pad the frame count to 8n+1 in total (8n-3 effective frames)
idx = list(range(total)) + [total - 1] * 4
F_target = largest_8n1_leq(len(idx))
if F_target == 0:
raise RuntimeError(f"Not enough frames after padding. Got {len(idx)}.")
idx = idx[:F_target]
print(f"Target Frames (8n-3): {F_target - 4}")
# Gather the selected frames as an [F,3,H,W] tensor
frames = input_tensor[idx] # [F,H,W,3]
frames = frames.permute(0, 3, 1, 2) * 2.0 - 1.0 # [F,3,H,W] -> [-1,1]
# Upsample (bicubic)
frames = F.interpolate(frames, scale_factor=scale, mode="bicubic", align_corners=False)
_, _, sH, sW = frames.shape
# Center crop
left = (sW - tW) // 2
top = (sH - tH) // 2
frames = frames[:, :, top : top + tH, left : left + tW]
# Output as [1, C, F, H, W]
vid = frames.permute(1, 0, 2, 3).unsqueeze(0).to(dtype)
return vid, tH, tW, F_target
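# Worked example of the frame-count handling (using 25 frames, the config's target_video_length):
#   total = 25 -> idx has 25 + 4 = 29 entries (the last frame repeated 4 times)
#   largest_8n1_leq(29) = ((29 - 1) // 8) * 8 + 1 = 25, so F_target = 25 (of the form 8n+1)
#   the reported effective frame count (8n-3) is F_target - 4 = 21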
def init_pipeline(model_path):
# print(torch.cuda.current_device(), torch.cuda.get_device_name(torch.cuda.current_device()))
mm = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
mm.load_models(
[
model_path + "/diffusion_pytorch_model_streaming_dmd.safetensors",
]
)
pipe = FlashVSRTinyPipeline.from_model_manager(mm, device="cuda")
pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).to("cuda", dtype=torch.bfloat16)
LQ_proj_in_path = model_path + "/LQ_proj_in.ckpt"
if os.path.exists(LQ_proj_in_path):
pipe.denoising_model().LQ_proj_in.load_state_dict(torch.load(LQ_proj_in_path, map_location="cpu"), strict=True)
pipe.denoising_model().LQ_proj_in.to("cuda")
multi_scale_channels = [512, 256, 128, 128]
pipe.TCDecoder = build_tcdecoder(new_channels=multi_scale_channels, new_latent_channels=16 + 768)
mis = pipe.TCDecoder.load_state_dict(torch.load(model_path + "/TCDecoder.ckpt"), strict=False)
# print(mis)
pipe.to("cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
pipe.init_cross_kv()
pipe.load_models_to_device(["dit", "vae"])
return pipe
class VSRWrapper:
def __init__(self, model_path, device: Optional[torch.device] = None):
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Setup torch for optimal performance
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Load model
self.dtype, self.device = torch.bfloat16, "cuda"
self.sparse_ratio = 2.0 # Recommended: 1.5 or 2.0. 1.5 → faster; 2.0 → more stable.
with ProfilingContext4DebugL2("Load VSR model"):
self.pipe = init_pipeline(model_path)
@ProfilingContext4DebugL2("VSR video")
def super_resolve_frames(
self,
video: torch.Tensor, # [T,H,W,C]
seed: float = 0.0,
scale: float = 2.0,
) -> torch.Tensor:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
LQ, th, tw, F = prepare_input_tensor(video, scale=scale, dtype=self.dtype, device=self.device)
video = self.pipe(
prompt="",
negative_prompt="",
cfg_scale=1.0,
num_inference_steps=1,
seed=seed,
LQ_video=LQ,
num_frames=F,
height=th,
width=tw,
is_full_block=False,
if_buffer=True,
topk_ratio=self.sparse_ratio * 768 * 1280 / (th * tw),
kv_ratio=3.0,
local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
color_fix=True,
)
video = (video + 1.0) / 2.0 # map [-1, 1] to [0, 1]
video = video.permute(1, 2, 3, 0).clamp(0.0, 1.0) # [C,T,H,W] -> [T,H,W,C]
return video
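# Usage sketch (model_path is a placeholder; requires a CUDA device): the wrapper takes a
# [T, H, W, C] float tensor in [0, 1] and returns the same layout, upscaled by `scale` and
# cropped so height and width are multiples of 128.
def _vsr_example(video_thwc, model_path="/path/to/FlashVSR"):
    wrapper = VSRWrapper(model_path)
    return wrapper.super_resolve_frames(video_thwc, seed=0, scale=2.0)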
@@ -652,6 +652,14 @@ class WanAudioRunner(WanRunner): # type:ignore
target_fps=target_fps,
)
if "video_super_resolution" in self.config and self.vsr_model is not None:
logger.info(f"Applying video super resolution with scale {self.config['video_super_resolution']['scale']}")
video_seg = self.vsr_model.super_resolve_frames(
video_seg,
seed=self.config["video_super_resolution"]["seed"],
scale=self.config["video_super_resolution"]["scale"],
)
if self.va_recorder:
self.va_recorder.pub_livestream(video_seg, audio_seg)
elif self.input_info.return_result_tensor:
......
#!/bin/bash
lightx2v_path=""
model_path=""
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export SENSITIVE_LAYER_DTYPE=None
python -m lightx2v.infer \
--model_cls seko_talk \
--task s2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_17_base_vsr.json \
--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4