Commit 486e6279 authored by root

Add support for running lightx2v on 8 GB GPUs

parent e74270f5
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "block",
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F",
"weight_auto_quant": true
}
}
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "sage_attn2",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "phase",
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F",
"weight_auto_quant": true
},
"use_tiling_vae": true,
"tiny_vae": true,
"tiny_vae_path": "/mnt/afs_2/gushiqiao/x2v_models/taew2_1.pth",
"text_encoder_offload_granularity": "block"
}
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"attention_type": "sage_attn2",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "block",
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F",
"weight_auto_quant": true
}
}
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"attention_type": "sage_attn2",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "phase",
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F",
"weight_auto_quant": true
},
"tiny_vae": true,
"tiny_vae_path": "/mnt/afs_2/gushiqiao/x2v_models/taew2_1.pth",
"text_encoder_offload_granularity": "block"
}
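The four JSON blocks above are example inference configs: all of them enable cpu_offload with int8 dynamic weight quantization, and the lower-memory variants additionally switch offload_granularity to "phase", enable the tiled/tiny VAE, and offload the text encoder block by block. A rough, hedged sketch of how such a config selects the offload path (the file name and helper below are illustrative, not part of the commit):

# Hedged sketch (not part of the commit): how a config like the ones above drives
# the offload dispatch added in this change.
import json

def select_offload_mode(config):
    # "block" prefetches whole transformer blocks; "phase" streams the three
    # sub-phases (self-attn, cross-attn, FFN) of each block for a lower peak.
    if not config.get("cpu_offload", False):
        return "gpu_resident"
    return config.get("offload_granularity", "block")

with open("wan_i2v_phase_offload.json") as f:  # illustrative path
    cfg = json.load(f)
print(select_offload_mode(cfg))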
@@ -53,8 +53,14 @@ class WeightModule:
                self._parameters[name].to_cpu()
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if isinstance(module, WeightModuleList):
                for i in range(len(module)):
                    for m in module[i]._modules.values():
                        if m is not None and hasattr(m, "to_cpu"):
                            m.to_cpu()
            else:
                if module is not None and hasattr(module, "to_cpu"):
                    module.to_cpu()

    def to_cuda(self):
        for name, param in self._parameters.items():
@@ -65,10 +71,16 @@ class WeightModule:
                self._parameters[name].to_cuda()
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if isinstance(module, WeightModuleList):
                for i in range(len(module)):
                    for m in module[i]._modules.values():
                        if m is not None and hasattr(m, "to_cuda"):
                            m.to_cuda()
            else:
                if module is not None and hasattr(module, "to_cuda"):
                    module.to_cuda()

-    def to_cpu_sync(self):
    def to_cpu_async(self):
        for name, param in self._parameters.items():
            if param is not None:
                if hasattr(param, "cpu"):
@@ -78,10 +90,16 @@ class WeightModule:
                self._parameters[name].to_cpu(non_blocking=True)
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if isinstance(module, WeightModuleList):
                for i in range(len(module)):
                    for m in module[i]._modules.values():
                        if m is not None and hasattr(m, "to_cpu"):
                            m.to_cpu(non_blocking=True)
            else:
                if module is not None and hasattr(module, "to_cpu"):
                    module.to_cpu(non_blocking=True)

-    def to_cuda_sync(self):
    def to_cuda_async(self):
        for name, param in self._parameters.items():
            if param is not None:
                if hasattr(param, "cuda"):
@@ -90,8 +108,14 @@ class WeightModule:
                self._parameters[name].to_cuda(non_blocking=True)
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if isinstance(module, WeightModuleList):
                for i in range(len(module)):
                    for m in module[i]._modules.values():
                        if m is not None and hasattr(m, "to_cuda"):
                            m.to_cuda(non_blocking=True)
            else:
                if module is not None and hasattr(module, "to_cuda"):
                    module.to_cuda(non_blocking=True)


class WeightModuleList(WeightModule):
...
import torch


class WeightAsyncStreamManager(object):
    def __init__(self):
        self.active_weights = [None for _ in range(2)]
        self.compute_stream = torch.cuda.Stream(priority=-1)
        self.load_stream = torch.cuda.Stream(priority=0)

@@ -10,9 +11,9 @@ class WeightStreamManager(object):
    def prefetch_weights(self, block_idx, blocks_weights):
        with torch.cuda.stream(self.load_stream):
            if self.active_weights[1] is not None:
                self.active_weights[1].to_cpu_async()
            new_weights = blocks_weights[block_idx]
            new_weights.to_cuda_async()
            self.active_weights[1] = new_weights

    def swap_weights(self):
@@ -23,3 +24,17 @@ class WeightStreamManager(object):
            self.active_weights[1],
            self.active_weights[0],
        )

    def prefetch_phase(self, block_idx, phase_idx, blocks):
        with torch.cuda.stream(self.load_stream):
            if self.active_weights[1] is not None:
                _, old_phase = self.active_weights[1]
                old_phase.to_cpu_async()
            new_phase = blocks[block_idx].compute_phases[phase_idx]
            new_phase.to_cuda_async()
            self.active_weights[1] = (phase_idx, new_phase)

    def swap_phases(self):
        self.compute_stream.synchronize()
        self.load_stream.synchronize()
        self.active_weights[0], self.active_weights[1] = self.active_weights[1], self.active_weights[0]
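# Hedged illustration (not part of the commit) of the double-buffering pattern
# implemented by WeightAsyncStreamManager above: slot 0 feeds the compute stream,
# slot 1 is filled by the lower-priority load stream, and a swap after two
# synchronizations hands the prefetched weights over. Toy classes only; assumes CUDA.
import torch


class ToyBlock:
    def __init__(self, dim):
        # Pinned host memory lets the non_blocking copies below actually overlap.
        self.w = torch.randn(dim, dim).pin_memory()

    def to_cuda_async(self):
        self.w = self.w.cuda(non_blocking=True)

    def to_cpu_async(self):
        self.w = self.w.to("cpu", non_blocking=True)


def run_blocks(blocks, x):
    compute_stream = torch.cuda.Stream(priority=-1)
    load_stream = torch.cuda.Stream(priority=0)
    active = [None, None]
    blocks[0].to_cuda_async()
    active[0] = blocks[0]
    for idx in range(len(blocks)):
        with torch.cuda.stream(compute_stream):
            x = x @ active[0].w  # stand-in for the real block computation
        if idx + 1 < len(blocks):
            with torch.cuda.stream(load_stream):
                if active[1] is not None:
                    active[1].to_cpu_async()
                blocks[idx + 1].to_cuda_async()
                active[1] = blocks[idx + 1]
        compute_stream.synchronize()
        load_stream.synchronize()
        active[0], active[1] = active[1], active[0]
    return x


if __name__ == "__main__":
    y = run_blocks([ToyBlock(64) for _ in range(4)], torch.randn(8, 64, device="cuda"))
    print(y.shape)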
@@ -39,15 +39,15 @@ class Conv3dWeight(Conv3dWeightTemplate):
        input_tensor = torch.nn.functional.conv3d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
        return input_tensor

    def to_cpu(self, non_blocking=False):
        self.weight = self.weight.to("cpu", non_blocking=non_blocking)
        if self.bias is not None:
            self.bias = self.bias.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        self.weight = self.weight.cuda(non_blocking=non_blocking)
        if self.bias is not None:
            self.bias = self.bias.cuda(non_blocking=non_blocking)

    def state_dict(self, destination=None):
        if destination is None:
...
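# Note (hedged, general PyTorch behaviour rather than something this diff changes):
# the non_blocking=True transfers above only overlap with compute when the host-side
# tensor lives in pinned (page-locked) memory; with ordinary pageable memory the copy
# degenerates to an effectively synchronous transfer.
import torch

pageable = torch.randn(1024, 1024)
pinned = torch.randn(1024, 1024).pin_memory()

side_stream = torch.cuda.Stream()
with torch.cuda.stream(side_stream):
    a = pageable.cuda(non_blocking=True)  # still serializes with the host in practice
    b = pinned.cuda(non_blocking=True)    # can genuinely overlap with other work
side_stream.synchronize()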
@@ -256,8 +256,10 @@ class T5Encoder(nn.Module):
        num_buckets,
        shared_pos=True,
        dropout=0.1,
        cpu_offload=False,
    ):
        super(T5Encoder, self).__init__()
        self.cpu_offload = cpu_offload
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
@@ -277,12 +279,28 @@ class T5Encoder(nn.Module):
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        if self.cpu_offload:
            self.token_embedding = self.token_embedding.cuda()
        x = self.token_embedding(ids)
        if self.cpu_offload:
            self.token_embedding = self.token_embedding.cpu()
        x = self.dropout(x)
        if self.cpu_offload and self.pos_embedding is not None:
            self.pos_embedding = self.pos_embedding.cuda()
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        if self.cpu_offload and self.pos_embedding is not None:
            self.pos_embedding = self.pos_embedding.cpu()
        for block in self.blocks:
            if self.cpu_offload:
                block = block.cuda()
            x = block(x, mask, pos_bias=e)
            if self.cpu_offload:
                block = block.cpu()
        if self.cpu_offload:
            self.norm = self.norm.cuda()
        x = self.norm(x)
        if self.cpu_offload:
            self.norm = self.norm.cpu()
        x = self.dropout(x)
        return x
@@ -432,15 +450,7 @@ def _t5(
    # set device
    model = model.to(dtype=dtype, device=device)
    return model
-    # init tokenizer
-    if return_tokenizer:
-        from .tokenizers import HuggingfaceTokenizer
-        tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
-        return model, tokenizer
-    else:
-        return model


def umt5_xxl(**kwargs):
@@ -470,15 +480,33 @@ class T5EncoderModel:
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
        cpu_offload=False,
        offload_granularity="model",
    ):
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path
        self.offload_granularity = offload_granularity
        # sync cpu offload
        self.cpu_offload = cpu_offload
        if self.cpu_offload:
            assert self.offload_granularity in ["block", "model"]

        # init model
        model = (
            umt5_xxl(
                encoder_only=True,
                return_tokenizer=False,
                dtype=dtype,
                device=device,
                cpu_offload=cpu_offload if self.offload_granularity == "block" else False,
            )
            .eval()
            .requires_grad_(False)
        )
        logging.info(f"loading {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
        self.model = model
@@ -495,8 +523,8 @@ class T5EncoderModel:
    def to_cuda(self):
        self.model = self.model.to("cuda")

    def infer(self, texts):
        if self.cpu_offload and self.offload_granularity == "model":
            self.to_cuda()

        ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
@@ -505,7 +533,7 @@ class T5EncoderModel:
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)

        if self.cpu_offload and self.offload_granularity == "model":
            self.to_cpu()
        return [u[:v] for u, v in zip(context, seq_lens)]
...
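# Hedged usage sketch (illustrative paths, not part of the commit): with
# offload_granularity="block" the UMT5 encoder can stay in CPU memory and stream
# one block at a time through the GPU, while "model" moves the whole encoder to
# CUDA for infer() and back afterwards.
import torch

t5 = T5EncoderModel(
    text_len=512,
    dtype=torch.bfloat16,
    device=torch.device("cpu"),
    checkpoint_path="models_t5_umt5-xxl-enc-bf16.pth",
    tokenizer_path="google/umt5-xxl",
    cpu_offload=True,
    offload_granularity="block",
)
context = t5.infer(["a corgi surfing a wave at sunset"])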
import torch
from einops import rearrange
from .utils_bf16 import apply_rotary_emb
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.utils.envs import *
@@ -16,8 +16,8 @@ class HunyuanTransformerInfer:
        self.mlp_hidden_dim = 12288
        self.parallel_attention = None
        if self.config["cpu_offload"]:
            self.double_weights_stream_mgr = WeightAsyncStreamManager()
            self.single_weights_stream_mgr = WeightAsyncStreamManager()
            self.infer_func = self._infer_with_offload
        else:
            self.infer_func = self._infer_without_offload
...
import torch
import math
from ..utils import compute_freqs, compute_freqs_causvid, compute_freqs_dist, apply_rotary_emb
-from lightx2v.common.offload.manager import WeightStreamManager
from lightx2v.utils.envs import *
from ..transformer_infer import WanTransformerInfer
...
import torch
import math
from .utils import rope_params, sinusoidal_embedding_1d
-import torch.cuda.amp as amp


class WanPreInfer:
...
import torch
from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.utils.envs import *
@@ -15,30 +15,32 @@ class WanTransformerInfer:
        self.window_size = config.get("window_size", (-1, -1))
        self.parallel_attention = None
        if self.config["cpu_offload"]:
            offload_granularity = self.config.get("offload_granularity", "block")
            self.weights_stream_mgr = WeightAsyncStreamManager()
            if offload_granularity == "block":
                self.infer_func = self._infer_with_offload
            elif offload_granularity == "phase":
                self.infer_func = self._infer_with_phases_offload
        else:
            self.infer_func = self._infer_without_offload

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def _calculate_q_k_len(self, q, k_lens):
-        lq, nq, c1 = q.size()
-        lk, nk, c1_k = k.size()
        # Handle query and key lengths (use `q_lens` and `k_lens` or set them to Lq and Lk if None)
        q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
        # We don't have a batch dimension anymore, so directly use the `q_lens` and `k_lens` values
        cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
        cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
        return cu_seqlens_q, cu_seqlens_k

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, grid_sizes, x, embed0, seq_lens, freqs, context):
        return self.infer_func(weights, grid_sizes, x, embed0, seq_lens, freqs, context)

    def _infer_with_offload(self, weights, grid_sizes, x, embed0, seq_lens, freqs, context):
        for block_idx in range(self.blocks_num):
            if block_idx == 0:
                self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
@@ -48,7 +50,6 @@ class WanTransformerInfer:
            x = self.infer_block(
                self.weights_stream_mgr.active_weights[0],
                grid_sizes,
-                embed,
                x,
                embed0,
                seq_lens,
@@ -62,12 +63,62 @@ class WanTransformerInfer:
        return x
    def _infer_with_phases_offload(self, weights, grid_sizes, x, embed0, seq_lens, freqs, context):
        for block_idx in range(weights.blocks_num):
            weights.blocks[block_idx].modulation.to_cuda()

            if embed0.dim() == 3:
                modulation = weights.blocks[block_idx].modulation.tensor.unsqueeze(2)
                current_embed0 = (modulation + embed0).chunk(6, dim=1)
                shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in current_embed0]
            elif embed0.dim() == 2:
                shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.blocks[block_idx].modulation.tensor + embed0).chunk(6, dim=1)

            for phase_idx in range(3):
                if block_idx == 0 and phase_idx == 0:
                    phase = weights.blocks[block_idx].compute_phases[phase_idx]
                    phase.to_cuda()
                    self.weights_stream_mgr.active_weights[0] = (phase_idx, phase)

                with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                    cur_phase_idx, cur_phase = self.weights_stream_mgr.active_weights[0]
                    if cur_phase_idx == 0:
                        x = self._infer_self_attn(
                            cur_phase,
                            x,
                            shift_msa,
                            scale_msa,
                            gate_msa,
                            grid_sizes,
                            freqs,
                            seq_lens,
                        )
                    elif cur_phase_idx == 1:
                        x = self._infer_cross_attn(cur_phase, x, context)
                    elif cur_phase_idx == 2:
                        x = self._infer_ffn(cur_phase, x, c_shift_msa, c_scale_msa, c_gate_msa)

                is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == 2
                if not is_last_phase:
                    next_block_idx = block_idx + 1 if cur_phase_idx == 2 else block_idx
                    next_phase_idx = (cur_phase_idx + 1) % 3
                    self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)

                self.weights_stream_mgr.swap_phases()

            weights.blocks[block_idx].modulation.to_cpu()
            torch.cuda.empty_cache()
        return x
    def _infer_without_offload(self, weights, grid_sizes, x, embed0, seq_lens, freqs, context):
        for block_idx in range(self.blocks_num):
            x = self.infer_block(
                weights.blocks[block_idx],
                grid_sizes,
-                embed,
                x,
                embed0,
                seq_lens,
@@ -76,21 +127,13 @@ class WanTransformerInfer:
            )
        return x
    def _infer_self_attn(self, weights, x, shift_msa, scale_msa, gate_msa, grid_sizes, freqs, seq_lens):
        if hasattr(weights, "smooth_norm1_weight"):
            norm1_weight = (1 + scale_msa) * weights.smooth_norm1_weight.tensor
            norm1_bias = shift_msa * weights.smooth_norm1_bias.tensor
        else:
            norm1_weight = 1 + scale_msa
            norm1_bias = shift_msa

        norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        norm1_out = (norm1_out * norm1_weight + norm1_bias).squeeze(0)
@@ -108,7 +151,7 @@ class WanTransformerInfer:
        q = apply_rotary_emb(q, freqs_i)
        k = apply_rotary_emb(k, freqs_i)

        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=seq_lens)
        if not self.parallel_attention:
            attn_out = weights.self_attn_1.apply(
@@ -117,8 +160,8 @@ class WanTransformerInfer:
                v=v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_k,
                max_seqlen_q=q.size(0),
                max_seqlen_kv=k.size(0),
                model_cls=self.config["model_cls"],
            )
        else:
@@ -129,25 +172,30 @@ class WanTransformerInfer:
                v=v,
                img_qkv_len=q.shape[0],
                cu_seqlens_qkv=cu_seqlens_q,
-                # cu_seqlens_qkv=cu_seqlens_qkv,
-                # max_seqlen_qkv=max_seqlen_qkv,
            )

        y = weights.self_attn_o.apply(attn_out)
        x.add_(y * gate_msa.squeeze(0))
        return x
    def _infer_cross_attn(self, weights, x, context):
        norm3_out = weights.norm3.apply(x)
        if self.task == "i2v":
            context_img = context[:257]
            context = context[257:]
        else:
            context_img = None

        n, d = self.num_heads, self.head_dim
        q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d)
        k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d)
        v = weights.cross_attn_v.apply(context).view(-1, n, d)
        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
            q,
            k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device),
        )
        attn_out = weights.cross_attn_1.apply(
            q=q,
@@ -155,18 +203,17 @@ class WanTransformerInfer:
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_k,
            max_seqlen_q=q.size(0),
            max_seqlen_kv=k.size(0),
            model_cls=self.config["model_cls"],
        )

        if self.task == "i2v" and context_img is not None:
            k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
            v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
            cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
                q,
-                k_img,
                k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
            )
@@ -176,28 +223,50 @@ class WanTransformerInfer:
                v=v_img,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_k,
                max_seqlen_q=q.size(0),
                max_seqlen_kv=k_img.size(0),
                model_cls=self.config["model_cls"],
            )
            attn_out = attn_out + img_attn_out

        attn_out = weights.cross_attn_o.apply(attn_out)
        x.add_(attn_out)
        return x
    def _infer_ffn(self, weights, x, c_shift_msa, c_scale_msa, c_gate_msa):
        if hasattr(weights, "smooth_norm2_weight"):
            norm2_weight = (1 + c_scale_msa.squeeze(0)) * weights.smooth_norm2_weight.tensor
            norm2_bias = c_shift_msa.squeeze(0) * weights.smooth_norm2_bias.tensor
        else:
            norm2_weight = 1 + c_scale_msa.squeeze(0)
            norm2_bias = c_shift_msa.squeeze(0)
        norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        y = weights.ffn_0.apply(norm2_out * norm2_weight + norm2_bias)
        y = torch.nn.functional.gelu(y, approximate="tanh")
        y = weights.ffn_2.apply(y)
        x.add_(y * c_gate_msa.squeeze(0))
        return x

    def infer_block(self, weights, grid_sizes, x, embed0, seq_lens, freqs, context):
        if embed0.dim() == 3:
            modulation = weights.modulation.tensor.unsqueeze(2)
            embed0 = (modulation + embed0).chunk(6, dim=1)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0]
        elif embed0.dim() == 2:
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
        x = self._infer_self_attn(
            weights.compute_phases[1],
            x,
            shift_msa,
            scale_msa,
            gate_msa,
            grid_sizes,
            freqs,
            seq_lens,
        )
        x = self._infer_cross_attn(weights.compute_phases[2], x, context)
        x = self._infer_ffn(weights.compute_phases[3], x, c_shift_msa, c_scale_msa, c_gate_msa)
        return x
@@ -52,11 +52,6 @@ class WanModel:
        else:
            raise Exception(f"Unsuppotred parallel_attn_type")

-        if self.config["cpu_offload"]:
-            self.to_cpu()
-        else:
-            self.to_cuda()

    def _init_infer_class(self):
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
@@ -188,7 +183,7 @@ class WanModel:
        self.post_weight.to_cuda()
        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        if self.config["feature_caching"] == "Tea":
@@ -199,7 +194,7 @@ class WanModel:
        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

            if self.config["feature_caching"] == "Tea":
...
import torch
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
from lightx2v.common.modules.weight_module import WeightModule
...
import torch
from lightx2v.utils.registry_factory import (
    MM_WEIGHT_REGISTER,
    LN_WEIGHT_REGISTER,
    RMS_WEIGHT_REGISTER,
    TENSOR_REGISTER,
    ATTN_WEIGHT_REGISTER,
)
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
@@ -26,57 +32,196 @@ class WanTransformerAttentionBlock(WeightModule):
        self.config = config
        self.quant_method = config["mm_config"].get("quant_method", None)
        self.sparge = config.get("sparge", False)

        self.register_parameter(
            "modulation",
            TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.modulation"),
        )
        self.compute_phases = WeightModuleList(
            [
                WanSelfAttention(block_index, task, mm_type, config),
                WanCrossAttention(block_index, task, mm_type, config),
                WanFFN(block_index, task, mm_type, config),
            ]
        )
        self.add_module("compute_phases", self.compute_phases)

class WanSelfAttention(WeightModule):
def __init__(self, block_index, task, mm_type, config):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.quant_method = config["mm_config"].get("quant_method", None)
self.sparge = config.get("sparge", False)
self.add_module(
"self_attn_q",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.q.weight",
f"blocks.{self.block_index}.self_attn.q.bias",
),
)
self.add_module(
"self_attn_k",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.k.weight",
f"blocks.{self.block_index}.self_attn.k.bias",
),
)
self.add_module(
"self_attn_v",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.v.weight",
f"blocks.{self.block_index}.self_attn.v.bias",
),
)
self.add_module(
"self_attn_o",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.o.weight",
f"blocks.{self.block_index}.self_attn.o.bias",
),
)
self.add_module(
"self_attn_norm_q",
RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_q.weight"),
)
self.add_module(
"self_attn_norm_k",
RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_k.weight"),
)
        if self.sparge:
            assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
            self.add_module(
                "self_attn_1",
                ATTN_WEIGHT_REGISTER["Sparge"](f"blocks.{self.block_index}"),
            )
            sparge_ckpt = torch.load(self.config["sparge_ckpt"])
            self.self_attn_1.load(sparge_ckpt)
        else:
            self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())

        if self.quant_method in ["smoothquant", "awq"]:
            self.register_parameter(
                "smooth_norm1_weight",
                TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm1.weight"),
            )
            self.register_parameter(
                "smooth_norm1_bias",
                TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm1.bias"),
            )
class WanCrossAttention(WeightModule):
def __init__(self, block_index, task, mm_type, config):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.add_module(
"norm3",
LN_WEIGHT_REGISTER["Default"](
f"blocks.{self.block_index}.norm3.weight",
f"blocks.{self.block_index}.norm3.bias",
eps=1e-6,
),
)
self.add_module(
"cross_attn_q",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.q.weight",
f"blocks.{self.block_index}.cross_attn.q.bias",
),
)
self.add_module(
"cross_attn_k",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.k.weight",
f"blocks.{self.block_index}.cross_attn.k.bias",
),
)
self.add_module(
"cross_attn_v",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.v.weight",
f"blocks.{self.block_index}.cross_attn.v.bias",
),
)
self.add_module(
"cross_attn_o",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.o.weight",
f"blocks.{self.block_index}.cross_attn.o.bias",
),
)
self.add_module(
"cross_attn_norm_q",
RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_q.weight"),
)
self.add_module(
"cross_attn_norm_k",
RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k.weight"),
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
if self.config.task == "i2v":
self.add_module(
"cross_attn_k_img",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.k_img.weight",
f"blocks.{self.block_index}.cross_attn.k_img.bias",
),
)
self.add_module(
"cross_attn_v_img",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.v_img.weight",
f"blocks.{self.block_index}.cross_attn.v_img.bias",
),
)
self.add_module(
"cross_attn_norm_k_img",
RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight"),
)
self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
class WanFFN(WeightModule):
def __init__(self, block_index, task, mm_type, config):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.quant_method = config["mm_config"].get("quant_method", None)
self.add_module(
"ffn_0",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.ffn.0.weight",
f"blocks.{self.block_index}.ffn.0.bias",
),
)
self.add_module(
"ffn_2",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.ffn.2.weight",
f"blocks.{self.block_index}.ffn.2.bias",
),
)
# For smoothquant or awq
        if self.quant_method in ["smoothquant", "awq"]:
            self.register_parameter(
                "smooth_norm2_weight",
                TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm2.weight"),
            )
            self.register_parameter(
                "smooth_norm2_bias",
                TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm2.bias"),
            )
@@ -15,6 +15,7 @@ from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
import torch.distributed as dist
from lightx2v.utils.memory_profiler import peak_memory_decorator
from loguru import logger
@@ -43,6 +44,8 @@ class WanRunner(DefaultRunner):
            checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
            tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
            shard_fn=None,
            cpu_offload=self.config.cpu_offload,
            offload_granularity=self.config.get("text_encoder_offload_granularity", "model"),
        )
        text_encoders = [text_encoder]
        model = WanModel(self.config.model_path, self.config, init_device)
@@ -53,11 +56,19 @@ class WanRunner(DefaultRunner):
            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
            logger.info(f"Loaded LoRA: {lora_name}")

        if self.config.get("tiny_vae", False):
            vae_model = WanVAE_tiny(
                vae_pth=self.config.tiny_vae_path,
                device=init_device,
            )
            vae_model = vae_model.to("cuda")
        else:
            vae_model = WanVAE(
                vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
                device=init_device,
                parallel=self.config.parallel_vae,
                use_tiling=self.config.get("use_tiling_vae", False),
            )
        if self.config.task == "i2v":
            image_encoder = CLIPModel(
                dtype=torch.float16,
@@ -68,6 +79,14 @@ class WanRunner(DefaultRunner):
                ),
                tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
            )
            if self.config.get("tiny_vae", False):
                org_vae = WanVAE(
                    vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
                    device=init_device,
                    parallel=self.config.parallel_vae,
                    use_tiling=self.config.get("use_tiling_vae", False),
                )
                image_encoder = [image_encoder, org_vae]
        return model, text_encoders, vae_model, image_encoder
@@ -84,17 +103,21 @@ class WanRunner(DefaultRunner):
    def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
        text_encoder_output = {}
        n_prompt = config.get("negative_prompt", "")
        context = text_encoders[0].infer([text])
        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""])
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null
        return text_encoder_output

    @peak_memory_decorator
    def run_image_encoder(self, config, image_encoder, vae_model):
        if self.config.get("tiny_vae", False):
            clip_image_encoder, vae_image_encoder = image_encoder[0], image_encoder[1]
        else:
            clip_image_encoder, vae_image_encoder = image_encoder, vae_model
        img = Image.open(config.image_path).convert("RGB")
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = clip_image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = config.target_height * config.target_width
@@ -111,7 +134,7 @@ class WanRunner(DefaultRunner):
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        vae_encode_out = vae_image_encoder.encode(
            [
                torch.concat(
                    [
@@ -131,14 +154,14 @@ class WanRunner(DefaultRunner):
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )
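# Hedged worked example of the latent-shape arithmetic above, using the values from
# the 480x832 / 81-frame configs; vae_stride=(4, 8, 8) is the usual Wan2.1 setting.
target_video_length, target_height, target_width = 81, 480, 832
vae_stride = (4, 8, 8)
latent_frames = (target_video_length - 1) // vae_stride[0] + 1  # 21
lat_h, lat_w = target_height // vae_stride[1], target_width // vae_stride[2]  # 60, 104
print(latent_frames, lat_h, lat_w)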
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple
import gc
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32,expandable_segments:True"
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = nn.ReLU(inplace=True)
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
"""
Apply a sequential model with memblocks to the given input.
Args:
- model: nn.Sequential of blocks to apply
- x: input data, of dimensions NTCHW
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
- show_progress_bar: if True, enables tqdm progressbar display
Returns NTCHW tensor of output data.
"""
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
N, T, C, H, W = x.shape
if parallel:
x = x.reshape(N * T, C, H, W)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
NT, C, H, W = x.shape
T = NT // N
_x = x.reshape(N, T, C, H, W)
mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
NT, C, H, W = x.shape
T = NT // N
x = x.view(N, T, C, H, W)
else:
# TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
# need to fix :(
out = []
# iterate over input timesteps and also iterate over blocks.
# because of the cursed TPool/TGrow blocks, this is not a nested loop,
# it's actually a ***graph traversal*** problem! so let's make a queue
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
# in addition to manually managing our queue, we also need to manually manage our progressbar.
# we'll update it for every source node that we consume.
progress_bar = tqdm(range(T), disable=not show_progress_bar)
# we'll also need a separate addressable memory per node as well
mem = [None] * len(model)
while work_queue:
xt, i = work_queue.pop(0)
if i == 0:
# new source node consumed
progress_bar.update(1)
if i == len(model):
# reached end of the graph, append result to output list
out.append(xt)
else:
# fetch the block to process
b = model[i]
if isinstance(b, MemBlock):
# mem blocks are simple since we're visiting the graph in causal order
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt
else:
xt_new = b(xt, mem[i])
mem[i].copy_(xt) # inplace might reduce mysterious pytorch memory allocations? doesn't help though
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_new, i + 1))
elif isinstance(b, TPool):
# pool blocks are miserable
if mem[i] is None:
mem[i] = [] # pool memory is itself a queue of inputs to pool
mem[i].append(xt)
if len(mem[i]) > b.stride:
# pool mem is in invalid state, we should have pooled before this
raise ValueError("???")
elif len(mem[i]) < b.stride:
# pool mem is not yet full, go back to processing the work queue
pass
else:
# pool mem is ready, run the pool block
N, C, H, W = xt.shape
xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
# reset the pool mem
mem[i] = []
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i + 1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C, H, W = xt.shape
# each tgrow has multiple successor nodes
for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_next, i + 1))
else:
# normal block with no funny business
xt = b(xt)
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i + 1))
progress_bar.close()
x = torch.stack(out, 1)
return x
class TAEHV(nn.Module):
latent_channels = 16
image_channels = 3
def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
"""Initialize pretrained TAEHV from the given checkpoint.
Arg:
checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
"""
super().__init__()
self.encoder = nn.Sequential(
conv(TAEHV.image_channels, 64),
nn.ReLU(inplace=True),
TPool(64, 2),
conv(64, 64, stride=2, bias=False),
MemBlock(64, 64),
MemBlock(64, 64),
MemBlock(64, 64),
TPool(64, 2),
conv(64, 64, stride=2, bias=False),
MemBlock(64, 64),
MemBlock(64, 64),
MemBlock(64, 64),
TPool(64, 1),
conv(64, 64, stride=2, bias=False),
MemBlock(64, 64),
MemBlock(64, 64),
MemBlock(64, 64),
conv(64, TAEHV.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(),
conv(TAEHV.latent_channels, n_f[0]),
nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]),
MemBlock(n_f[0], n_f[0]),
MemBlock(n_f[0], n_f[0]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
TGrow(n_f[0], 1),
conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]),
MemBlock(n_f[1], n_f[1]),
MemBlock(n_f[1], n_f[1]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]),
MemBlock(n_f[2], n_f[2]),
MemBlock(n_f[2], n_f[2]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True),
conv(n_f[3], TAEHV.image_channels),
)
if checkpoint_path is not None:
self.load_state_dict(self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)))
def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed.
Args:
sd: state dict to patch
"""
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
if isinstance(layer, TGrow):
key = f"decoder.{i}.conv.weight"
if sd[key].shape[0] > new_sd[key].shape[0]:
# take the last-timestep output channels
sd[key] = sd[key][-new_sd[key].shape[0] :]
return sd
def encode_video(self, x, parallel=True, show_progress_bar=True):
"""Encode a sequence of frames.
Args:
x: input NTCHW RGB (C=3) tensor with values in [0, 1].
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW latent tensor with ~Gaussian values.
"""
return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
def decode_video(self, x, parallel=True, show_progress_bar=True):
"""Decode a sequence of frames.
Args:
x: input NTCHW latent (C=16) tensor with ~Gaussian values.
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW RGB tensor with ~[0, 1] values.
"""
x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
return x[:, self.frames_to_trim :]
def forward(self, x):
return self.c(x)
@torch.no_grad()
def main():
"""Run TAEHV roundtrip reconstruction on the given video paths."""
import sys
import cv2 # no highly esteemed deed is commemorated here
class VideoTensorReader:
def __init__(self, video_file_path):
self.cap = cv2.VideoCapture(video_file_path)
assert self.cap.isOpened(), f"Could not load {video_file_path}"
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
def __iter__(self):
return self
def __next__(self):
ret, frame = self.cap.read()
if not ret:
self.cap.release()
raise StopIteration # End of video or error
return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # BGR HWC -> RGB CHW
class VideoTensorWriter:
def __init__(self, video_file_path, width_height, fps=30):
self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, width_height)
assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
def write(self, frame_tensor):
assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR)) # RGB CHW -> BGR HWC
def __del__(self):
if hasattr(self, "writer"):
self.writer.release()
dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
dtype = torch.float16
print("Using device", dev, "and dtype", dtype)
taehv = TAEHV().to(dev, dtype)
for video_path in sys.argv[1:]:
print(f"Processing {video_path}...")
video_in = VideoTensorReader(video_path)
video = torch.stack(list(video_in), 0)[None]
# convert to a device tensor scaled to [0, 1]
vid_dev = video.to(dev, dtype).div_(255.0)
if video.numel() < 100_000_000:
print(f" {video_path} seems small enough, will process all frames in parallel")
vid_enc = taehv.encode_video(vid_dev)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc)
print(f" Decoded {video_path} -> {vid_dec.shape}")
else:
print(f" {video_path} seems large, will process each frame sequentially")
vid_enc = taehv.encode_video(vid_dev, parallel=False)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc, parallel=False)
print(f" Decoded {video_path} -> {vid_dec.shape}")
video_out_path = video_path + ".reconstructed_by_taehv.mp4"
video_out = VideoTensorWriter(video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
video_out.write(frame)
print(f" Saved to {video_out_path}")
if __name__ == "__main__":
main()
...@@ -517,7 +517,15 @@ class WanVAE_(nn.Module): ...@@ -517,7 +517,15 @@ class WanVAE_(nn.Module):
self.attn_scales = attn_scales self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1] self.temperal_upsample = temperal_downsample[::-1]
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
# modules # modules
self.encoder = Encoder3d( self.encoder = Encoder3d(
dim, dim,
...@@ -546,6 +554,134 @@ class WanVAE_(nn.Module): ...@@ -546,6 +554,134 @@ class WanVAE_(nn.Module):
x_recon = self.decode(z) x_recon = self.decode(z)
return x_recon, mu, log_var return x_recon, mu, log_var
def blend_v(self, a, b, blend_extent):
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a, b, blend_extent):
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
return b
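# Illustrative sketch (not part of the model): blend_v / blend_h cross-fade the
# overlapping rows / columns of two neighbouring tiles with linear weights, so the
# contribution shifts from tile `a` to tile `b` across `blend_extent` positions.
# A minimal 1-D analogue of the same weighting, using hypothetical tensors:
#
#   import torch
#   a = torch.ones(8)   # right edge of the previous tile
#   b = torch.zeros(8)  # left edge of the current tile
#   blend = 4
#   for x in range(blend):
#       b[x] = a[-blend + x] * (1 - x / blend) + b[x] * (x / blend)
#   # b[:4] now ramps 1.00, 0.75, 0.50, 0.25 before the untouched region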
def tiled_encode(self, x, scale):
_, _, num_frames, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
self.clear_cache()
time = []
frame_range = 1 + (num_frames - 1) // 4
for k in range(frame_range):
self._enc_conv_idx = [0]
if k == 0:
tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
else:
tile = x[
:,
:,
1 + 4 * (k - 1) : 1 + 4 * k,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
mu, log_var = self.conv1(tile).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
time.append(mu)
row.append(torch.cat(time, dim=2))
rows.append(row)
self.clear_cache()
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
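# Worked example (illustrative, assuming the tile defaults set in __init__ above and a
# spatial compression ratio of 8, as in the usual Wan configuration):
#   tile_latent_min    = 256 // 8 = 32
#   tile_latent_stride = 192 // 8 = 24
#   blend              = 32 - 24  = 8 latent rows/columns of overlap per seam
# so a 480x832 input is encoded as 256x256 pixel tiles on a 192-pixel grid
# (3 x 5 tiles), blended over 8 latent positions, and cropped to 60x104 latents.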
def tiled_decode(self, z, scale):
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
_, _, num_frames, height, width = z.shape
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
self.clear_cache()
time = []
for k in range(num_frames):
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.conv2(tile)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
self.clear_cache()
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
return dec
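# Note (illustrative): unlike tiled_encode, blending here happens in pixel space, so with
# the same defaults each seam is cross-faded over 256 - 192 = 64 output pixels and every
# tile contributes a 192x192 pixel region before the final crop to sample_height x sample_width.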
def encode(self, x, scale): def encode(self, x, scale):
self.clear_cache() self.clear_cache()
## cache ## cache
...@@ -660,10 +796,12 @@ class WanVAE: ...@@ -660,10 +796,12 @@ class WanVAE:
dtype=torch.float, dtype=torch.float,
device="cuda", device="cuda",
parallel=False, parallel=False,
use_tiling=False,
): ):
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
self.parallel = parallel self.parallel = parallel
self.use_tiling = use_tiling
mean = [ mean = [
-0.7571, -0.7571,
...@@ -735,7 +873,10 @@ class WanVAE: ...@@ -735,7 +873,10 @@ class WanVAE:
if args.cpu_offload: if args.cpu_offload:
self.to_cuda() self.to_cuda()
out = [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos] if self.use_tiling:
out = [self.model.tiled_encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
else:
out = [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
if args.cpu_offload: if args.cpu_offload:
self.to_cpu() self.to_cpu()
...@@ -806,6 +947,8 @@ class WanVAE: ...@@ -806,6 +947,8 @@ class WanVAE:
else: else:
logger.info("Fall back to naive decode mode") logger.info("Fall back to naive decode mode")
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1) images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
elif self.use_tiling:
images = self.model.tiled_decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
else: else:
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1) images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
......
import torch
import torch.nn as nn
from ..tae import TAEHV
from lightx2v.utils.memory_profiler import peak_memory_decorator
class DotDict(dict):
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
class WanVAE_tiny(nn.Module):
def __init__(self, vae_pth="taew2_1.pth", dtype=torch.bfloat16, device="cuda"):
super().__init__()
self.dtype = dtype
self.device = torch.device(device)
self.taehv = TAEHV(vae_pth).to(self.dtype)
self.temperal_downsample = [True, True, False]
self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))
@peak_memory_decorator
@torch.no_grad()
def decode(self, latents, generator=None, return_dict=None, config=None):
latents = latents.unsqueeze(0)
n, c, t, h, w = latents.shape
# parallel=False keeps peak memory low; set parallel=True for faster decoding at higher memory cost
return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=False).transpose(1, 2).mul_(2).sub_(1)
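# Minimal usage sketch (illustrative only; the checkpoint path and latent shape below are
# assumptions for a 480x832, 81-frame clip, not values taken from this repository):
#
#   import torch
#   vae = WanVAE_tiny(vae_pth="taew2_1.pth", dtype=torch.bfloat16).cuda()
#   latents = torch.randn(16, 21, 60, 104, device="cuda")   # CTHW latents
#   video = vae.decode(latents)                              # NCTHW RGB in [-1, 1]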
import json import json
import os import os
from easydict import EasyDict from easydict import EasyDict
from loguru import logger
def get_default_config(): def get_default_config():
...@@ -38,4 +39,8 @@ def set_config(args): ...@@ -38,4 +39,8 @@ def set_config(args):
model_config = json.load(f) model_config = json.load(f)
config.update(model_config) config.update(model_config)
if (config.target_video_length - 1) % config.vae_stride[0] != 0:
logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest valid length.")
config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1
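# e.g. with config.vae_stride[0] == 4, a requested target_video_length of 83 becomes
# 83 // 4 * 4 + 1 = 81, which satisfies (num_frames - 1) % 4 == 0.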
return config return config