Commit 5c241f86 authored by gushiqiao, committed by GitHub

Support running on low-memory machines (lazy load), fix some bugs, and reconstruct quantization. (#61)



* Reconstruct quantization and fix a memory leak bug.

* Support lazy-load inference.

* Reconstruct quantization.

* Fix Hunyuan bugs.

* Delete tmp file.

---------
Co-authored-by: root <root@pt-c0b333b3a1834e81a0d4d5f412c6ffa1-worker-0.pt-c0b333b3a1834e81a0d4d5f412c6ffa1.ns-devsft-3460edd0.svc.cluster.local>
Co-authored-by: gushiqiao <gushqiaio@sensetime.com>
Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
parent b7d2d43f
import torch
import torch.nn as nn
from vllm import _custom_ops as ops
try:
import q8_kernels.functional as Q8F
except ImportError:
Q8F = None
class QuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
else:
self.register_buffer("bias", None)
def act_quant_func(self, x):
input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
return input_tensor_quant, input_tensor_scale
def forward(self, x):
input_tensor_quant, input_tensor_scale = self.act_quant_func(x)
output_tensor = Q8F.linear.q8_linear(
input_tensor_quant,
self.weight,
self.bias.float() if self.bias is not None else None,
input_tensor_scale,
self.weight_scale.float(),
fuse_gelu=False,
out_dtype=torch.bfloat16,
)
return output_tensor
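# Illustrative helper (not part of this commit): convert a float nn.Linear into the
# QuantLinearInt8 layout above, assuming symmetric per-output-channel int8 weight
# quantization with per-row scales.
def quantize_linear_int8(fp_linear: nn.Linear) -> "QuantLinearInt8":
    q = QuantLinearInt8(fp_linear.in_features, fp_linear.out_features, bias=fp_linear.bias is not None)
    w = fp_linear.weight.detach().float()  # [out_features, in_features]
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0  # per-output-channel scale
    q.weight.copy_((w / scale).round().clamp(-128, 127).to(torch.int8))
    q.weight_scale.copy_(scale)
    if fp_linear.bias is not None:
        q.bias.copy_(fp_linear.bias.detach().float())
    return q
# Usage sketch: module.q = quantize_linear_int8(module.q) before saving a quantized checkpoint.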
......@@ -9,6 +9,8 @@ import torch.nn.functional as F
from .tokenizer import HuggingfaceTokenizer
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8
__all__ = [
"T5Model",
......@@ -63,7 +65,7 @@ class T5LayerNorm(nn.Module):
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
self.dim = dim
......@@ -71,11 +73,17 @@ class T5Attention(nn.Module):
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
else:
linear_cls = nn.Linear
# layers
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
self.q = linear_cls(dim, dim_attn, bias=False)
self.k = linear_cls(dim, dim_attn, bias=False)
self.v = linear_cls(dim, dim_attn, bias=False)
self.o = linear_cls(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context=None, mask=None, pos_bias=None):
......@@ -104,7 +112,7 @@ class T5Attention(nn.Module):
# compute attention (T5 does not use scaling)
attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
x = torch.einsum("bnij,bjnc->binc", attn, v)
# output
......@@ -115,15 +123,20 @@ class T5Attention(nn.Module):
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1):
def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None):
super(T5FeedForward, self).__init__()
self.dim = dim
self.dim_ffn = dim_ffn
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
else:
linear_cls = nn.Linear
# layers
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False), GELU())
self.fc1 = linear_cls(dim, dim_ffn, bias=False)
self.fc2 = linear_cls(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
......@@ -135,16 +148,7 @@ class T5FeedForward(nn.Module):
class T5SelfAttention(nn.Module):
def __init__(
self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.1,
):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None):
super(T5SelfAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
......@@ -155,9 +159,9 @@ class T5SelfAttention(nn.Module):
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
def forward(self, x, mask=None, pos_bias=None):
......@@ -244,20 +248,9 @@ class T5RelativeEmbedding(nn.Module):
class T5Encoder(nn.Module):
def __init__(
self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.1,
cpu_offload=False,
):
def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
super(T5Encoder, self).__init__()
self.cpu_offload = cpu_offload
self.dim = dim
self.dim_attn = dim_attn
......@@ -266,16 +259,17 @@ class T5Encoder(nn.Module):
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
self.quant_scheme = quant_scheme
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
# self.apply(init_weights)
def forward(self, ids, mask=None):
if self.cpu_offload:
......@@ -301,7 +295,7 @@ class T5Encoder(nn.Module):
if self.cpu_offload:
self.norm = self.norm.cpu()
x = self.dropout(x)
return x
return x.to(torch.bfloat16)
class T5Decoder(nn.Module):
......@@ -480,11 +474,17 @@ class T5EncoderModel:
shard_fn=None,
cpu_offload=False,
offload_granularity="model",
t5_quantized=False,
t5_quantized_ckpt=None,
quant_scheme=None,
):
self.text_len = text_len
self.dtype = dtype
self.device = device
self.checkpoint_path = checkpoint_path
if t5_quantized_ckpt is not None and t5_quantized:
self.checkpoint_path = t5_quantized_ckpt
else:
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
self.offload_granularity = offload_granularity
......@@ -493,20 +493,23 @@ class T5EncoderModel:
if self.cpu_offload:
assert self.offload_granularity in ["block", "model"]
# init model
model = (
umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device,
cpu_offload=cpu_offload if self.offload_granularity == "block" else False,
cpu_offload=(cpu_offload if self.offload_granularity == "block" else False),
quantized=t5_quantized,
quant_scheme=quant_scheme,
)
.eval()
.requires_grad_(False)
)
logging.info(f"loading {checkpoint_path}")
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
logger.info(f"Loading weights from {self.checkpoint_path}")
model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu", weights_only=True))
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
......
......@@ -9,10 +9,9 @@ import torch.nn.functional as F
import torchvision.transforms as T
from lightx2v.attentions import attention
from lightx2v.models.input_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8
from .xlm_roberta import XLMRoberta
__all__ = [
"XLMRobertaCLIP",
......@@ -48,7 +47,7 @@ class LayerNorm(nn.LayerNorm):
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
......@@ -59,8 +58,14 @@ class SelfAttention(nn.Module):
self.proj_dropout = proj_dropout
# layers
self.to_qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
else:
linear_cls = nn.Linear
self.to_qkv = linear_cls(dim, dim * 3)
self.proj = linear_cls(dim, dim)
def forward(self, x):
"""
......@@ -86,7 +91,6 @@ class SwiGLU(nn.Module):
super().__init__()
self.dim = dim
self.mid_dim = mid_dim
# layers
self.fc1 = nn.Linear(dim, mid_dim)
self.fc2 = nn.Linear(dim, mid_dim)
......@@ -99,7 +103,7 @@ class SwiGLU(nn.Module):
class AttentionBlock(nn.Module):
def __init__(self, dim, mlp_ratio, num_heads, post_norm=False, causal=False, activation="quick_gelu", attn_dropout=0.0, proj_dropout=0.0, norm_eps=1e-5):
def __init__(self, dim, mlp_ratio, num_heads, post_norm=False, causal=False, activation="quick_gelu", attn_dropout=0.0, proj_dropout=0.0, norm_eps=1e-5, quantized=False, quant_scheme=None):
assert activation in ["quick_gelu", "gelu", "swi_glu"]
super().__init__()
self.dim = dim
......@@ -110,13 +114,19 @@ class AttentionBlock(nn.Module):
self.norm_eps = norm_eps
# layers
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
else:
linear_cls = nn.Linear
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme)
self.norm2 = LayerNorm(dim, eps=norm_eps)
if activation == "swi_glu":
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
else:
self.mlp = nn.Sequential(nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
self.mlp = nn.Sequential(linear_cls(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), linear_cls(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
if self.post_norm:
......@@ -189,6 +199,8 @@ class VisionTransformer(nn.Module):
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5,
quantized=False,
quant_scheme=None,
):
if image_size % patch_size != 0:
logger.info("[WARNING] image_size is not divisible by patch_size", flush=True)
......@@ -217,7 +229,9 @@ class VisionTransformer(nn.Module):
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.transformer = nn.Sequential(*[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps) for _ in range(num_layers)])
self.transformer = nn.Sequential(
*[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme) for _ in range(num_layers)]
)
self.post_norm = LayerNorm(dim, eps=norm_eps)
# head
......@@ -252,28 +266,6 @@ class VisionTransformer(nn.Module):
return x
class XLMRobertaWithHead(XLMRoberta):
def __init__(self, **kwargs):
self.out_dim = kwargs.pop("out_dim")
super().__init__(**kwargs)
# head
mid_dim = (self.dim + self.out_dim) // 2
self.head = nn.Sequential(nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False))
def forward(self, ids):
# xlm-roberta
x = super().forward(ids)
# average pooling
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
# head
x = self.head(x)
return x
class XLMRobertaCLIP(nn.Module):
def __init__(
self,
......@@ -292,15 +284,12 @@ class XLMRobertaCLIP(nn.Module):
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5,
quantized=False,
quant_scheme=None,
):
super().__init__()
self.embed_dim = embed_dim
......@@ -317,10 +306,6 @@ class XLMRobertaCLIP(nn.Module):
self.max_text_len = max_text_len
self.type_size = type_size
self.pad_id = pad_id
self.text_dim = text_dim
self.text_heads = text_heads
self.text_layers = text_layers
self.text_post_norm = text_post_norm
self.norm_eps = norm_eps
# models
......@@ -340,40 +325,11 @@ class XLMRobertaCLIP(nn.Module):
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout,
norm_eps=norm_eps,
)
self.textual = XLMRobertaWithHead(
vocab_size=vocab_size,
max_seq_len=max_text_len,
type_size=type_size,
pad_id=pad_id,
dim=text_dim,
out_dim=embed_dim,
num_heads=text_heads,
num_layers=text_layers,
post_norm=text_post_norm,
dropout=text_dropout,
quantized=quantized,
quant_scheme=quant_scheme,
)
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
def forward(self, imgs, txt_ids):
"""
imgs: [B, 3, H, W] of torch.float32.
- mean: [0.48145466, 0.4578275, 0.40821073]
- std: [0.26862954, 0.26130258, 0.27577711]
txt_ids: [B, L] of torch.long.
Encoded by data.CLIPTokenizer.
"""
xi = self.visual(imgs)
xt = self.textual(txt_ids)
return xi, xt
def param_groups(self):
groups = [
{"params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")], "weight_decay": 0.0},
{"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
]
return groups
def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
# init a model on device
......@@ -414,11 +370,6 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
......@@ -428,20 +379,29 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
class CLIPModel:
def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme):
self.dtype = dtype
self.device = device
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
self.quantized = clip_quantized
if self.quantized:
self.checkpoint_path = clip_quantized_ckpt
else:
self.checkpoint_path = checkpoint_path
logger.info(f"Loading weights from {self.checkpoint_path}")
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device)
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
)
self.model = self.model.eval().requires_grad_(False)
logging.info(f"loading {checkpoint_path}")
self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace")
weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
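# The text tower (XLMRobertaWithHead) was removed from XLMRobertaCLIP above, so its
# "textual.*" weights are filtered out of the checkpoint before load_state_dict.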
keys = list(weight_dict.keys())
for key in keys:
if "textual" in key:
weight_dict.pop(key)
self.model.load_state_dict(weight_dict)
def visual(self, videos, args):
if args.cpu_offload:
......
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ["XLMRoberta", "xlm_roberta_large"]
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
# compute attention
p = self.dropout.p if self.training else 0.0
x = F.scaled_dot_product_attention(q, k, v, mask, p)
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
# output
x = self.o(x)
x = self.dropout(x)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.post_norm = post_norm
self.eps = eps
# layers
self.attn = SelfAttention(dim, num_heads, dropout, eps)
self.norm1 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
self.norm2 = nn.LayerNorm(dim, eps=eps)
def forward(self, x, mask):
if self.post_norm:
x = self.norm1(x + self.attn(x, mask))
x = self.norm2(x + self.ffn(x))
else:
x = x + self.attn(self.norm1(x), mask)
x = x + self.ffn(self.norm2(x))
return x
class XLMRoberta(nn.Module):
"""
XLMRobertaModel with no pooler and no LM head.
"""
def __init__(self, vocab_size=250002, max_seq_len=514, type_size=1, pad_id=1, dim=1024, num_heads=16, num_layers=24, post_norm=True, dropout=0.1, eps=1e-5):
super().__init__()
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.type_size = type_size
self.pad_id = pad_id
self.dim = dim
self.num_heads = num_heads
self.num_layers = num_layers
self.post_norm = post_norm
self.eps = eps
# embeddings
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
self.type_embedding = nn.Embedding(type_size, dim)
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
self.dropout = nn.Dropout(dropout)
# blocks
self.blocks = nn.ModuleList([AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)])
# norm layer
self.norm = nn.LayerNorm(dim, eps=eps)
def forward(self, ids):
"""
ids: [B, L] of torch.LongTensor.
"""
b, s = ids.shape
mask = ids.ne(self.pad_id).long()
# embeddings
x = self.token_embedding(ids) + self.type_embedding(torch.zeros_like(ids)) + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
if self.post_norm:
x = self.norm(x)
x = self.dropout(x)
# blocks
mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
for block in self.blocks:
x = block(x, mask)
# output
if not self.post_norm:
x = self.norm(x)
return x
def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
"""
XLMRobertaLarge adapted from Huggingface.
"""
# params
cfg = dict(vocab_size=250002, max_seq_len=514, type_size=1, pad_id=1, dim=1024, num_heads=16, num_layers=24, post_norm=True, dropout=0.1, eps=1e-5)
cfg.update(**kwargs)
# init a model on device
with torch.device(device):
model = XLMRoberta(**cfg)
return model
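# Usage sketch (illustrative, not part of this commit): the encoder takes padded
# token ids and returns per-token features of shape [B, L, dim].
if __name__ == "__main__":
    model = xlm_roberta_large().eval()
    ids = torch.full((1, 16), 1, dtype=torch.long)  # pad_id is 1
    ids[0, :4] = torch.tensor([0, 2345, 67, 2])  # arbitrary in-vocabulary token ids
    with torch.no_grad():
        features = model(ids)  # [1, 16, 1024]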
import os
import sys
import torch
import json
from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
......@@ -13,6 +13,7 @@ import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.utils.envs import *
from loguru import logger
from safetensors import safe_open
class HunyuanModel:
......@@ -25,13 +26,15 @@ class HunyuanModel:
self.config = config
self.device = device
self.args = args
self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
if self.dit_quantized:
assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
self._init_infer_class()
self._init_weights()
if GET_RUNNING_FLAG() == "save_naive_quant":
assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
self.save_weights(self.config.quant_model_path)
sys.exit(0)
self._init_infer()
if config["parallel_attn_type"]:
......@@ -45,18 +48,6 @@ class HunyuanModel:
if self.config["cpu_offload"]:
self.to_cpu()
def _init_infer_class(self):
self.pre_infer_class = HunyuanPreInfer
self.post_infer_class = HunyuanPostInfer
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = HunyuanTransformerInfer
elif self.config["feature_caching"] == "TaylorSeer":
self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = HunyuanTransformerInferTeaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
def _load_ckpt(self):
if self.args.task == "t2v":
ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
......@@ -65,18 +56,41 @@ class HunyuanModel:
weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
return weight_dict
def _load_ckpt_quant_model(self):
assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
logger.info(f"Loading quant model from {self.config.quant_model_path}")
quant_weights_path = os.path.join(self.config.quant_model_path, "quant_weights.pth")
weight_dict = torch.load(quant_weights_path, map_location=self.device, weights_only=True)
def _load_quant_ckpt(self):
ckpt_path = self.config.dit_quantized_ckpt
logger.info(f"Loading quant dit model from {ckpt_path}")
if ckpt_path.endswith(".pth"):
logger.info(f"Loading {ckpt_path} as PyTorch model.")
weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
else:
index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
if not index_files:
raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")
index_path = os.path.join(ckpt_path, index_files[0])
logger.info(f" Using safetensors index: {index_path}")
with open(index_path, "r") as f:
index_data = json.load(f)
weight_dict = {}
for filename in set(index_data["weight_map"].values()):
safetensor_path = os.path.join(ckpt_path, filename)
with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
logger.info(f"Loading weights from {safetensor_path}")
for k in f.keys():
weight_dict[k] = f.get_tensor(k)
if weight_dict[k].dtype == torch.float:
weight_dict[k] = weight_dict[k].to(torch.bfloat16)
return weight_dict
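# For reference, the *.index.json consumed above follows the standard safetensors
# index layout (parameter and shard names below are illustrative only):
# {
#     "metadata": {"total_size": 12345678},
#     "weight_map": {
#         "double_blocks.0.img_attn_qkv.weight": "model-00001-of-00002.safetensors",
#         "double_blocks.0.img_attn_qkv.weight_scale": "model-00001-of-00002.safetensors"
#     }
# }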
def _init_weights(self):
if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False) or self.config["mm_config"].get("mm_type", "Default") == "Default":
if not self.dit_quantized or self.weight_auto_quant:
weight_dict = self._load_ckpt()
else:
weight_dict = self._load_ckpt_quant_model()
weight_dict = self._load_quant_ckpt()
# init weights
self.pre_weight = self.pre_weight_class(self.config)
self.post_weight = self.post_weight_class(self.config)
......@@ -146,3 +160,15 @@ class HunyuanModel:
self.scheduler.cnt += 1
if self.scheduler.cnt == self.scheduler.num_steps:
self.scheduler.cnt = 0
def _init_infer_class(self):
self.pre_infer_class = HunyuanPreInfer
self.post_infer_class = HunyuanPostInfer
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = HunyuanTransformerInfer
elif self.config["feature_caching"] == "TaylorSeer":
self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = HunyuanTransformerInferTeaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
import torch
from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.common.offload.manager import WeightAsyncStreamManager, LazyWeightAsyncStreamManager
from lightx2v.utils.envs import *
......@@ -24,8 +24,25 @@ class WanTransformerInfer:
if offload_granularity == "block":
self.infer_func = self._infer_with_offload
elif offload_granularity == "phase":
self.infer_func = self._infer_with_phases_offload
self.weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.blocks_num, offload_ratio=offload_ratio, phases_num=self.phases_num)
if not self.config.get("lazy_load", False):
self.infer_func = self._infer_with_phases_offload
else:
self.infer_func = self._infer_with_phases_lazy_offload
if not self.config.get("lazy_load", False):
self.weights_stream_mgr = WeightAsyncStreamManager(
blocks_num=self.blocks_num,
offload_ratio=offload_ratio,
phases_num=self.phases_num,
)
else:
self.weights_stream_mgr = LazyWeightAsyncStreamManager(
blocks_num=self.blocks_num,
offload_ratio=offload_ratio,
phases_num=self.phases_num,
num_disk_workers=self.config.get("num_disk_workers", 2),
max_memory=self.config.get("max_memory", 2),
)
else:
self.infer_func = self._infer_without_offload
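# Illustrative config fragment for the dispatch above (key names match the config
# lookups in this method; values are hypothetical, and offload_granularity /
# offload_ratio are assumed to come from the same config):
# {
#     "offload_granularity": "phase",   # "block" or "phase"
#     "offload_ratio": 1.0,
#     "lazy_load": true,                # switch to LazyWeightAsyncStreamManager
#     "num_disk_workers": 2,
#     "max_memory": 2
# }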
......@@ -33,10 +50,7 @@ class WanTransformerInfer:
self.scheduler = scheduler
def _calculate_q_k_len(self, q, k_lens):
# Handle query and key lengths (use `q_lens` and `k_lens` or set them to Lq and Lk if None)
q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
# We don't have a batch dimension anymore, so directly use the `q_lens` and `k_lens` values
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
return cu_seqlens_q, cu_seqlens_k
......@@ -45,6 +59,7 @@ class WanTransformerInfer:
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
# bug
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(self.blocks_num):
if block_idx == 0:
......@@ -107,15 +122,77 @@ class WanTransformerInfer:
elif cur_phase_idx == 2:
x = self._infer_ffn(cur_phase, x, c_shift_msa, c_scale_msa, c_gate_msa)
is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == 2
is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
if not is_last_phase:
next_block_idx = block_idx + 1 if cur_phase_idx == 2 else block_idx
next_phase_idx = (cur_phase_idx + 1) % 3
next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
next_phase_idx = (phase_idx + 1) % self.phases_num
self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
self.weights_stream_mgr.swap_phases()
weights.blocks[block_idx].modulation.to_cpu()
torch.cuda.empty_cache()
return x
def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
self.weights_stream_mgr._async_prefetch_block(weights)
for block_idx in range(weights.blocks_num):
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
weights.blocks[block_idx].modulation.to_cuda()
if embed0.dim() == 3:
modulation = weights.blocks[block_idx].modulation.tensor.unsqueeze(2)
current_embed0 = (modulation + embed0).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in current_embed0]
elif embed0.dim() == 2:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.blocks[block_idx].modulation.tensor + embed0).chunk(6, dim=1)
for phase_idx in range(self.weights_stream_mgr.phases_num):
if block_idx == 0 and phase_idx == 0:
obj_key = (block_idx, phase_idx)
phase = self.weights_stream_mgr.pin_memory_buffer.get(obj_key)
phase.to_cuda()
self.weights_stream_mgr.active_weights[0] = (obj_key, phase)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
(
(
_,
cur_phase_idx,
),
cur_phase,
) = self.weights_stream_mgr.active_weights[0]
if cur_phase_idx == 0:
x = self._infer_self_attn(
cur_phase,
x,
shift_msa,
scale_msa,
gate_msa,
grid_sizes,
freqs,
seq_lens,
)
elif cur_phase_idx == 1:
x = self._infer_cross_attn(cur_phase, x, context)
elif cur_phase_idx == 2:
x = self._infer_ffn(cur_phase, x, c_shift_msa, c_scale_msa, c_gate_msa)
if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
next_phase_idx = (phase_idx + 1) % self.weights_stream_mgr.phases_num
self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
self.weights_stream_mgr.swap_phases()
weights.blocks[block_idx].modulation.to_cpu()
self.weights_stream_mgr._async_prefetch_block(weights)
torch.cuda.empty_cache()
......
......@@ -31,14 +31,16 @@ class WanModel:
def __init__(self, model_path, config, device):
self.model_path = model_path
self.config = config
self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
if self.dit_quantized:
assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
self.device = device
self._init_infer_class()
self._init_weights()
if GET_RUNNING_FLAG() == "save_naive_quant":
assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
self.save_weights(self.config.quant_model_path)
sys.exit(0)
self._init_infer()
self.current_lora = None
......@@ -64,9 +66,9 @@ class WanModel:
use_bfloat16 = self.config.get("use_bfloat16", True)
with safe_open(file_path, framework="pt") as f:
if use_bfloat16:
tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).pin_memory().to(self.device) for key in f.keys()}
tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
else:
tensor_dict = {key: f.get_tensor(key).pin_memory().to(self.device) for key in f.keys()}
tensor_dict = {key: f.get_tensor(key).to(self.device) for key in f.keys()}
return tensor_dict
def _load_ckpt(self):
......@@ -82,22 +84,19 @@ class WanModel:
return weight_dict
def _load_quant_ckpt(self):
assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
ckpt_path = self.config.quant_model_path
logger.info(f"Loading quant model from {ckpt_path}")
ckpt_path = self.config.dit_quantized_ckpt
logger.info(f"Loading quant dit model from {ckpt_path}")
quant_pth_file = os.path.join(ckpt_path, "quant_weights.pth")
if os.path.exists(quant_pth_file):
logger.info("Found quant_weights.pth, loading as PyTorch model.")
weight_dict = torch.load(quant_pth_file, map_location=self.device, weights_only=True)
if ckpt_path.endswith(".pth"):
logger.info(f"Loading {ckpt_path} as PyTorch model.")
weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
else:
index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
if not index_files:
raise FileNotFoundError(f"No quant_weights.pth or *.index.json found in {ckpt_path}")
raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")
index_path = os.path.join(ckpt_path, index_files[0])
logger.info(f"quant_weights.pth not found. Using safetensors index: {index_path}")
logger.info(f" Using safetensors index: {index_path}")
with open(index_path, "r") as f:
index_data = json.load(f)
......@@ -114,14 +113,48 @@ class WanModel:
return weight_dict
def _load_quant_split_ckpt(self):
lazy_load_model_path = self.config.dit_quantized_ckpt
logger.info(f"Loading splited quant model from {lazy_load_model_path}")
pre_post_weight_dict, transformer_weight_dict = {}, {}
safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
for k in f.keys():
pre_post_weight_dict[k] = f.get_tensor(k)
if pre_post_weight_dict[k].dtype == torch.float:
pre_post_weight_dict[k] = pre_post_weight_dict[k].to(torch.bfloat16)
safetensors_pattern = os.path.join(lazy_load_model_path, "block_*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if not safetensors_files:
raise FileNotFoundError(f"No .safetensors files found in directory: {lazy_load_model_path}")
for file_path in safetensors_files:
with safe_open(file_path, framework="pt") as f:
for k in f.keys():
if "modulation" in k:
transformer_weight_dict[k] = f.get_tensor(k)
if transformer_weight_dict[k].dtype == torch.float:
transformer_weight_dict[k] = transformer_weight_dict[k].to(torch.bfloat16)
return pre_post_weight_dict, transformer_weight_dict
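# Hypothetical helper (not part of this commit) that writes the split layout
# _load_quant_split_ckpt expects: one non_block.safetensors holding pre/post weights
# plus one block_{i}.safetensors per transformer block.
def save_split_ckpt(weight_dict, out_dir, num_blocks):
    from safetensors.torch import save_file

    os.makedirs(out_dir, exist_ok=True)
    non_block = {k: v for k, v in weight_dict.items() if not k.startswith("blocks.")}
    save_file(non_block, os.path.join(out_dir, "non_block.safetensors"))
    for i in range(num_blocks):
        prefix = f"blocks.{i}."
        block = {k: v for k, v in weight_dict.items() if k.startswith(prefix)}
        save_file(block, os.path.join(out_dir, f"block_{i}.safetensors"))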
def _init_weights(self, weight_dict=None):
if weight_dict is None:
if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False) or self.config["mm_config"].get("mm_type", "Default") == "Default":
if not self.dit_quantized or self.weight_auto_quant:
self.original_weight_dict = self._load_ckpt()
else:
self.original_weight_dict = self._load_quant_ckpt()
if not self.config.get("lazy_load", False):
self.original_weight_dict = self._load_quant_ckpt()
else:
(
self.original_weight_dict,
self.transformer_weight_dict,
) = self._load_quant_split_ckpt()
else:
self.original_weight_dict = weight_dict
# init weights
self.pre_weight = self.pre_weight_class(self.config)
self.post_weight = self.post_weight_class(self.config)
......@@ -129,35 +162,16 @@ class WanModel:
# load weights
self.pre_weight.load(self.original_weight_dict)
self.post_weight.load(self.original_weight_dict)
self.transformer_weights.load(self.original_weight_dict)
if hasattr(self, "transformer_weight_dict"):
self.transformer_weights.load(self.transformer_weight_dict)
else:
self.transformer_weights.load(self.original_weight_dict)
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
def save_weights(self, save_path):
if not os.path.exists(save_path):
os.makedirs(save_path)
pre_state_dict = self.pre_weight.state_dict()
logger.info(pre_state_dict.keys())
post_state_dict = self.post_weight.state_dict()
logger.info(post_state_dict.keys())
transformer_state_dict = self.transformer_weights.state_dict()
logger.info(transformer_state_dict.keys())
save_dict = {}
save_dict.update(pre_state_dict)
save_dict.update(post_state_dict)
save_dict.update(transformer_state_dict)
save_path = os.path.join(save_path, "quant_weights.pth")
torch.save(save_dict, save_path)
logger.info(f"Save weights to {save_path}")
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.pre_infer.set_scheduler(scheduler)
......
import torch
import os
from lightx2v.utils.registry_factory import (
MM_WEIGHT_REGISTER,
LN_WEIGHT_REGISTER,
......@@ -7,6 +8,7 @@ from lightx2v.utils.registry_factory import (
ATTN_WEIGHT_REGISTER,
)
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from safetensors import safe_open
class WanTransformerWeights(WeightModule):
......@@ -36,21 +38,28 @@ class WanTransformerAttentionBlock(WeightModule):
"modulation",
TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.modulation"),
)
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load:
lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors")
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
else:
self.lazy_load_file = None
self.compute_phases = WeightModuleList(
[
WanSelfAttention(block_index, task, mm_type, config),
WanCrossAttention(block_index, task, mm_type, config),
WanFFN(block_index, task, mm_type, config),
WanSelfAttention(block_index, task, mm_type, config, self.lazy_load, self.lazy_load_file),
WanCrossAttention(block_index, task, mm_type, config, self.lazy_load, self.lazy_load_file),
WanFFN(block_index, task, mm_type, config, self.lazy_load, self.lazy_load_file),
]
)
self.add_module("compute_phases", self.compute_phases)
# i2v
class WanSelfAttention(WeightModule):
def __init__(self, block_index, task, mm_type, config):
def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
......@@ -59,11 +68,16 @@ class WanSelfAttention(WeightModule):
self.quant_method = config["mm_config"].get("quant_method", None)
self.sparge = config.get("sparge", False)
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.add_module(
"self_attn_q",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.q.weight",
f"blocks.{self.block_index}.self_attn.q.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -71,6 +85,8 @@ class WanSelfAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.k.weight",
f"blocks.{self.block_index}.self_attn.k.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -78,6 +94,8 @@ class WanSelfAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.v.weight",
f"blocks.{self.block_index}.self_attn.v.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -85,15 +103,25 @@ class WanSelfAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.o.weight",
f"blocks.{self.block_index}.self_attn.o.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"self_attn_norm_q",
RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_q.weight"),
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"blocks.{self.block_index}.self_attn.norm_q.weight",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"self_attn_norm_k",
RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_k.weight"),
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"blocks.{self.block_index}.self_attn.norm_k.weight",
self.lazy_load,
self.lazy_load_file,
),
)
if self.sparge:
assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
......@@ -108,27 +136,39 @@ class WanSelfAttention(WeightModule):
if self.quant_method in ["smoothquant", "awq"]:
self.register_parameter(
"smooth_norm1_weight",
TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm1.weight"),
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm1.weight",
self.lazy_load,
self.lazy_load_file,
),
)
self.register_parameter(
"smooth_norm1_bias",
TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm1.bias"),
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm1.bias",
self.lazy_load,
self.lazy_load_file,
),
)
class WanCrossAttention(WeightModule):
def __init__(self, block_index, task, mm_type, config):
def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.add_module(
"norm3",
LN_WEIGHT_REGISTER["Default"](
f"blocks.{self.block_index}.norm3.weight",
f"blocks.{self.block_index}.norm3.bias",
self.lazy_load,
self.lazy_load_file,
eps=1e-6,
),
)
......@@ -137,6 +177,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.q.weight",
f"blocks.{self.block_index}.cross_attn.q.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -144,6 +186,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.k.weight",
f"blocks.{self.block_index}.cross_attn.k.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -151,6 +195,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.v.weight",
f"blocks.{self.block_index}.cross_attn.v.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -158,15 +204,25 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.o.weight",
f"blocks.{self.block_index}.cross_attn.o.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"cross_attn_norm_q",
RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_q.weight"),
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"blocks.{self.block_index}.cross_attn.norm_q.weight",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"cross_attn_norm_k",
RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k.weight"),
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"blocks.{self.block_index}.cross_attn.norm_k.weight",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
......@@ -176,6 +232,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.k_img.weight",
f"blocks.{self.block_index}.cross_attn.k_img.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -183,29 +241,39 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.v_img.weight",
f"blocks.{self.block_index}.cross_attn.v_img.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"cross_attn_norm_k_img",
RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight"),
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"blocks.{self.block_index}.cross_attn.norm_k_img.weight",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
class WanFFN(WeightModule):
def __init__(self, block_index, task, mm_type, config):
def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.quant_method = config["mm_config"].get("quant_method", None)
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.add_module(
"ffn_0",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.ffn.0.weight",
f"blocks.{self.block_index}.ffn.0.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -213,15 +281,25 @@ class WanFFN(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.ffn.2.weight",
f"blocks.{self.block_index}.ffn.2.bias",
self.lazy_load,
self.lazy_load_file,
),
)
if self.quant_method in ["smoothquant", "awq"]:
self.register_parameter(
"smooth_norm2_weight",
TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm2.weight"),
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm3.weight",
self.lazy_load,
self.lazy_load_file,
),
)
self.register_parameter(
"smooth_norm2_bias",
TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm2.bias"),
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm3.bias",
self.lazy_load,
self.lazy_load_file,
),
)
......@@ -26,6 +26,7 @@ class DefaultRunner:
logger.warning("No prompt enhancer server available, disable prompt enhancer.")
def init_modules(self):
self.set_init_device()
if self.config["mode"] == "split_server":
self.tensor_transporter = TensorTransporter()
self.image_transporter = ImageTransporter()
......@@ -45,7 +46,8 @@ class DefaultRunner:
else:
self.run_input_encoder = self.run_input_encoder_server_t2v
else:
self.load_model()
if not self.config.get("lazy_load", False):
self.load_model()
self.run_dit = self.run_dit_local
self.run_vae_decoder = self.run_vae_decoder_local
if self.config["task"] == "i2v":
......@@ -53,23 +55,21 @@ class DefaultRunner:
else:
self.run_input_encoder = self.run_input_encoder_local_t2v
def get_init_device(self):
def set_init_device(self):
if self.config["parallel_attn_type"]:
cur_rank = dist.get_rank()
torch.cuda.set_device(cur_rank)
if self.config.cpu_offload:
init_device = torch.device("cpu")
self.init_device = torch.device("cpu")
else:
init_device = torch.device("cuda")
return init_device
self.init_device = torch.device("cuda")
@ProfilingContext("Load models")
def load_model(self):
init_device = self.get_init_device()
self.text_encoders = self.load_text_encoder(init_device)
self.model = self.load_transformer(init_device)
self.image_encoder = self.load_image_encoder(init_device)
self.vae_encoder, self.vae_decoder = self.load_vae(init_device)
self.model = self.load_transformer()
self.text_encoders = self.load_text_encoder()
self.image_encoder = self.load_image_encoder()
self.vae_encoder, self.vae_decoder = self.load_vae()
def check_sub_servers(self, task_type):
urls = self.config.get("sub_servers", {}).get(task_type, [])
......@@ -124,7 +124,10 @@ class DefaultRunner:
def end_run(self):
self.model.scheduler.clear()
del self.inputs, self.model.scheduler
if self.config.get("lazy_load", False):
del self.model
torch.cuda.empty_cache()
gc.collect()
@ProfilingContext("Run Encoders")
async def run_input_encoder_local_i2v(self):
......@@ -133,16 +136,22 @@ class DefaultRunner:
clip_encoder_out = self.run_image_encoder(img)
vae_encode_out, kwargs = self.run_vae_encoder(img)
text_encoder_output = self.run_text_encoder(prompt, img)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
@ProfilingContext("Run Encoders")
async def run_input_encoder_local_t2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt, None)
torch.cuda.empty_cache()
gc.collect()
return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
@ProfilingContext("Run DiT")
async def run_dit_local(self, kwargs):
if self.config.get("lazy_load", False):
self.model = self.load_transformer()
self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
latents, generator = self.run()
......@@ -151,7 +160,12 @@ class DefaultRunner:
@ProfilingContext("Run VAE Decoder")
async def run_vae_decoder_local(self, latents, generator):
if self.config.get("lazy_load", False):
self.vae_decoder = self.load_vae_decoder()
images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
if self.config.get("lazy_load", False):
torch.cuda.empty_cache()
gc.collect()
return images
@ProfilingContext("Save video")
......@@ -228,12 +242,16 @@ class DefaultRunner:
n_prompt = self.config.get("negative_prompt", "")
img = Image.open(self.config["image_path"]).convert("RGB")
clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders_i2v(prompt, img, n_prompt)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
async def run_input_encoder_server_t2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
n_prompt = self.config.get("negative_prompt", "")
text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
torch.cuda.empty_cache()
gc.collect()
return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
async def run_dit_server(self, kwargs):
......@@ -265,5 +283,5 @@ class DefaultRunner:
images = await self.run_vae_decoder(latents, generator)
self.save_video(images)
del latents, generator, images
gc.collect()
torch.cuda.empty_cache()
gc.collect()
......@@ -21,23 +21,23 @@ class HunyuanRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
def load_transformer(self, init_device):
return HunyuanModel(self.config.model_path, self.config, init_device, self.config)
def load_transformer(self):
return HunyuanModel(self.config.model_path, self.config, self.init_device, self.config)
def load_image_encoder(self, init_device):
def load_image_encoder(self):
return None
def load_text_encoder(self, init_device):
def load_text_encoder(self):
if self.config.task == "t2v":
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(self.config.model_path, "text_encoder"), init_device)
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(self.config.model_path, "text_encoder"), self.init_device)
else:
text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(self.config.model_path, "text_encoder_i2v"), init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(self.config.model_path, "text_encoder_2"), init_device)
text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(self.config.model_path, "text_encoder_i2v"), self.init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(self.config.model_path, "text_encoder_2"), self.init_device)
text_encoders = [text_encoder_1, text_encoder_2]
return text_encoders
def load_vae(self, init_device):
vae_model = VideoEncoderKLCausal3DModel(self.config.model_path, dtype=torch.float16, device=init_device, config=self.config)
def load_vae(self):
vae_model = VideoEncoderKLCausal3DModel(self.config.model_path, dtype=torch.float16, device=self.init_device, config=self.config)
return vae_model, vae_model
def init_scheduler(self):
......
import os
import gc
import numpy as np
import torch
import torchvision.transforms.functional as TF
......@@ -25,63 +26,83 @@ class WanRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
def load_transformer(self, init_device):
model = WanModel(self.config.model_path, self.config, init_device)
def load_transformer(self):
model = WanModel(
self.config.model_path,
self.config,
self.init_device,
)
if self.config.lora_path:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
lora_wrapper = WanLoraWrapper(model)
lora_name = lora_wrapper.load_lora(self.config.lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model)
logger.info(f"Loaded LoRA: {lora_name}")
return model
def load_image_encoder(self, init_device):
def load_image_encoder(self):
image_encoder = None
if self.config.task == "i2v":
image_encoder = CLIPModel(
dtype=torch.float16,
device=init_device,
device=self.init_device,
checkpoint_path=os.path.join(
self.config.model_path,
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
),
tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
clip_quantized=self.config.get("clip_quantized", False),
clip_quantized_ckpt=self.config.get("clip_quantized_ckpt", None),
quant_scheme=self.config.get("clip_quant_scheme", None),
)
return image_encoder
def load_text_encoder(self, init_device):
def load_text_encoder(self):
text_encoder = T5EncoderModel(
text_len=self.config["text_len"],
dtype=torch.bfloat16,
device=init_device,
device=self.init_device,
checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
shard_fn=None,
cpu_offload=self.config.cpu_offload,
offload_granularity=self.config.get("text_encoder_offload_granularity", "model"),
offload_granularity=self.config.get("t5_offload_granularity", "model"),
t5_quantized=self.config.get("t5_quantized", False),
t5_quantized_ckpt=self.config.get("t5_quantized_ckpt", None),
quant_scheme=self.config.get("t5_quant_scheme", None),
)
text_encoders = [text_encoder]
return text_encoders
def load_vae(self, init_device):
def load_vae_encoder(self):
vae_config = {
"vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
"device": init_device,
"device": self.init_device,
"parallel": self.config.parallel_vae,
"use_tiling": self.config.get("use_tiling_vae", False),
}
use_tiny_decoder = self.config.get("tiny_vae", False)
is_i2v = self.config.task == "i2v"
if use_tiny_decoder:
if self.config.task != "i2v":
return None
else:
return WanVAE(**vae_config)
def load_vae_decoder(self):
vae_config = {
"vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
"device": self.init_device,
"parallel": self.config.parallel_vae,
"use_tiling": self.config.get("use_tiling_vae", False),
}
if self.config.get("tiny_vae", False):
vae_decoder = WanVAE_tiny(
vae_pth=self.config.tiny_vae_path,
device=init_device,
device=self.init_device,
).to("cuda")
vae_encoder = WanVAE(**vae_config) if is_i2v else None
else:
vae_decoder = WanVAE(**vae_config)
vae_encoder = vae_decoder if is_i2v else None
return vae_decoder
return vae_encoder, vae_decoder
def load_vae(self):
return self.load_vae_encoder(), self.load_vae_decoder()
def init_scheduler(self):
if self.config.feature_caching == "NoCaching":
......@@ -93,17 +114,29 @@ class WanRunner(DefaultRunner):
self.model.set_scheduler(scheduler)
def run_text_encoder(self, text, img):
if self.config.get("lazy_load", False):
self.text_encoders = self.load_text_encoder()
text_encoder_output = {}
n_prompt = self.config.get("negative_prompt", "")
context = self.text_encoders[0].infer([text])
context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
if self.config.get("lazy_load", False):
del self.text_encoders[0]
torch.cuda.empty_cache()
gc.collect()
text_encoder_output["context"] = context
text_encoder_output["context_null"] = context_null
return text_encoder_output
def run_image_encoder(self, img):
if self.config.get("lazy_load", False):
self.image_encoder = self.load_image_encoder()
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
if self.config.get("lazy_load", False):
del self.image_encoder
torch.cuda.empty_cache()
gc.collect()
return clip_encoder_out
def run_vae_encoder(self, img):
......@@ -120,11 +153,19 @@ class WanRunner(DefaultRunner):
self.config.lat_h, kwargs["lat_h"] = lat_h, lat_h
self.config.lat_w, kwargs["lat_w"] = lat_w, lat_w
msk = torch.ones(1, self.config.target_video_length, lat_h, lat_w, device=torch.device("cuda"))
msk = torch.ones(
1,
self.config.target_video_length,
lat_h,
lat_w,
device=torch.device("cuda"),
)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
if self.config.get("lazy_load", False):
self.vae_encoder = self.load_vae_encoder()
vae_encode_out = self.vae_encoder.encode(
[
torch.concat(
......@@ -137,12 +178,22 @@ class WanRunner(DefaultRunner):
],
self.config,
)[0]
if self.config.get("lazy_load", False):
del self.vae_encoder
torch.cuda.empty_cache()
gc.collect()
vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
return vae_encode_out, kwargs
def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
image_encoder_output = {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
image_encoder_output = {
"clip_encoder_out": clip_encoder_out,
"vae_encode_out": vae_encode_out,
}
return {
"text_encoder_output": text_encoder_output,
"image_encoder_output": image_encoder_output,
}
def set_target_shape(self):
ret = {}
......@@ -167,4 +218,11 @@ class WanRunner(DefaultRunner):
return ret
def save_video_func(self, images):
cache_video(tensor=images, save_file=self.config.save_video_path, fps=self.config.get("fps", 16), nrow=1, normalize=True, value_range=(-1, 1))
cache_video(
tensor=images,
save_file=self.config.save_video_path,
fps=self.config.get("fps", 16),
nrow=1,
normalize=True,
value_range=(-1, 1),
)
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantized weights
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls hunyuan \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_i2v_save_quant.json \
--prompt "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_1.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_i2v.mp4
sleep 2
# =========================
# load quantized weights and run inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls hunyuan \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_i2v_save_quant.json \
--prompt "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_1.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_i2v.mp4
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantized weights
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls hunyuan \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_t2v_save_quant.json \
--prompt "A cat walks on the grass, realistic style." \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_t2v.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls hunyuan \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_t2v_save_quant.json \
--prompt "A cat walks on the grass, realistic style." \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_t2v.mp4
......@@ -7,7 +7,7 @@ model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
......@@ -28,37 +28,12 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantization weight
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v_save_quant.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v_save_quant.json \
--config_json ${lightx2v_path}/configs/quantization/wan_i2v_quant_auto.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_tea.mp4
......@@ -7,7 +7,7 @@ model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
......@@ -28,35 +28,12 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantization weight
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_save_quant.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_save_quant.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--config_json ${lightx2v_path}/configs/quantization/wan_i2v_quant_offline.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_tea.mp4
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantized weights
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls wan2.1_causvid \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_causvid_save_quant.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_causvid.mp4
sleep 2
# =========================
# load quantized weights and run inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls wan2.1_causvid \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_causvid_save_quant.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_causvid.mp4
# Model Conversion Tool
A powerful utility for converting model weights between different formats and performing quantization tasks.
## Diffusers
Facilitates mutual conversion between the Diffusers architecture and the LightX2V architecture.
### Lightx2v->Diffusers
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P \
--output /Path/To/Wan2.1-I2V-14B-480P-Diffusers \
--direction forward
```
### Diffusers->Lightx2v
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers \
--output /Path/To/Wan2.1-I2V-14B-480P \
--direction backward
```
## Quantization
This tool supports converting FP32/FP16/BF16 model weights to INT8 or FP8.
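For reference, the INT8 export pairs each quantized weight matrix with a per-output-channel scale. The sketch below illustrates symmetric per-channel INT8 quantization; it is a minimal illustration only, and the converter's actual implementation may differ.
```python
import torch

def quantize_weight_int8(weight: torch.Tensor):
    """Symmetric per-output-channel INT8 quantization (illustrative sketch)."""
    w = weight.float()                                  # (out_features, in_features)
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0   # one scale per output channel
    scale = scale.clamp(min=1e-8)                       # guard against all-zero rows
    w_int8 = torch.round(w / scale).clamp(-128, 127).to(torch.int8)
    return w_int8, scale                                # dequantize: w_int8.float() * scale
```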
### Wan DIT
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth \
--output_name wan_int8 \
--dtype torch.int8 \
--model_type wan_dit
```
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth \
--output_name wan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_dit
```
### Hunyuan DIT
```bash
python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth \
--output_name hunyuan_int8 \
--dtype torch.int8 \
--model_type hunyuan_dit
```
```bash
python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth \
--output_name hunyuan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type hunyuan_dit
```
### Wan T5EncoderModel
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/output \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-int8 \
--dtype torch.int8 \
--model_type wan_t5
```
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/output \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_t5
```
### Wan CLIPModel
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--quantized \
--output /Path/To/output \
--output_ext .pth \
--output_name clip_int8 \
--dtype torch.int8 \
--model_type wan_clip
```
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--quantized \
--output /Path/To/output \
--output_ext .pth \
--output_name clip_fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_clip
```
# Model Conversion Tool
A powerful utility for converting model weights between different formats and performing quantization tasks.
## Diffusers
Supports mutual conversion between the Diffusers architecture and the LightX2V architecture.
### Lightx2v->Diffusers
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P \
--output /Path/To/Wan2.1-I2V-14B-480P-Diffusers \
--direction forward
```
### Diffusers->Lightx2v
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers \
--output /Path/To/Wan2.1-I2V-14B-480P \
--direction backward
```
## Quantization
This tool supports converting **FP32/FP16/BF16** model weights to **INT8** or **FP8**.
### Wan DIT
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth \
--output_name wan_int8 \
--dtype torch.int8 \
--model_type wan_dit
```
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth \
--output_name wan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_dit
```
### Hunyuan DIT
```bash
python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth \
--output_name hunyuan_int8 \
--dtype torch.int8 \
--model_type hunyuan_dit
```
```bash
python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth \
--output_name hunyuan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type hunyuan_dit
```
### Wan T5EncoderModel
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/output \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-int8 \
--dtype torch.int8 \
--model_type wan_t5
```
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/output \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_t5
```
### Wan CLIPModel
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--quantized \
--output /Path/To/output \
--output_ext .pth \
--output_name clip_int8 \
--dtype torch.int8 \
--model_type wan_clip
```
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--quantized \
--output /Path/To/output \
--output_ext .pth \
--output_name clip_fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_clip
```