Commit 5c241f86 authored by gushiqiao, committed by GitHub

Support running on low-memory machines, fix some bugs, and reconstruct quantization. (#61)



* Reconstruct quantization and fix a memory leak bug.

* Support lazy-load inference.

* Reconstruct quantization.

* Fix Hunyuan bugs.

* Delete temporary files.

---------
Co-authored-by: root <root@pt-c0b333b3a1834e81a0d4d5f412c6ffa1-worker-0.pt-c0b333b3a1834e81a0d4d5f412c6ffa1.ns-devsft-3460edd0.svc.cluster.local>
Co-authored-by: gushiqiao <gushqiaio@sensetime.com>
Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
parent b7d2d43f
import torch
import torch.nn as nn
from vllm import _custom_ops as ops

try:
    import q8_kernels.functional as Q8F
except ImportError:
    Q8F = None


class QuantLinearInt8(nn.Module):
    """Linear layer with pre-quantized int8 weights and on-the-fly int8 activation quantization."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # int8 weights with per-output-channel float32 scales
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
        if bias:
            self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        # dynamic symmetric int8 quantization of the activations via vLLM's scaled_int8_quant
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def forward(self, x):
        if Q8F is None:
            raise ImportError("q8_kernels is required for QuantLinearInt8 but is not installed.")
        input_tensor_quant, input_tensor_scale = self.act_quant_func(x)
        output_tensor = Q8F.linear.q8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float() if self.bias is not None else None,
            input_tensor_scale,
            self.weight_scale.float(),
            fuse_gelu=False,
            out_dtype=torch.bfloat16,
        )
        return output_tensor
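Editor's note: the module above only holds already-quantized int8 weights; the offline conversion itself is not part of this commit. As a point of reference, here is a minimal sketch of how such buffers could be produced with plain symmetric per-output-channel weight quantization, assuming lightx2v and its vllm dependency are importable; the helper name is illustrative, not part of the commit.

import torch
import torch.nn as nn

from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8


def quantize_linear_to_int8(linear: nn.Linear) -> QuantLinearInt8:
    # Illustrative helper: symmetric per-output-channel weight quantization to int8.
    qlinear = QuantLinearInt8(linear.in_features, linear.out_features, bias=linear.bias is not None)
    w = linear.weight.data.float()
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    qlinear.weight.copy_((w / scale).round().clamp(-128, 127).to(torch.int8))
    qlinear.weight_scale.copy_(scale)
    if linear.bias is not None:
        qlinear.bias.copy_(linear.bias.data.float())
    return qlinear

At inference time the activations are quantized on the fly by act_quant_func, so only the weights need to be converted ahead of time.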
@@ -9,6 +9,8 @@ import torch.nn.functional as F
 from .tokenizer import HuggingfaceTokenizer
 from loguru import logger
+from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8
+
 __all__ = [
     "T5Model",
@@ -63,7 +65,7 @@ class T5LayerNorm(nn.Module):
 class T5Attention(nn.Module):
-    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
+    def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None):
         assert dim_attn % num_heads == 0
         super(T5Attention, self).__init__()
         self.dim = dim
@@ -71,11 +73,17 @@ class T5Attention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = dim_attn // num_heads
 
+        if quantized:
+            if quant_scheme == "int8":
+                linear_cls = QuantLinearInt8
+        else:
+            linear_cls = nn.Linear
+
         # layers
-        self.q = nn.Linear(dim, dim_attn, bias=False)
-        self.k = nn.Linear(dim, dim_attn, bias=False)
-        self.v = nn.Linear(dim, dim_attn, bias=False)
-        self.o = nn.Linear(dim_attn, dim, bias=False)
+        self.q = linear_cls(dim, dim_attn, bias=False)
+        self.k = linear_cls(dim, dim_attn, bias=False)
+        self.v = linear_cls(dim, dim_attn, bias=False)
+        self.o = linear_cls(dim_attn, dim, bias=False)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x, context=None, mask=None, pos_bias=None):
@@ -104,7 +112,7 @@ class T5Attention(nn.Module):
         # compute attention (T5 does not use scaling)
         attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
-        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
+        attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
         x = torch.einsum("bnij,bjnc->binc", attn, v)
 
         # output
@@ -115,15 +123,20 @@ class T5Attention(nn.Module):
 class T5FeedForward(nn.Module):
-    def __init__(self, dim, dim_ffn, dropout=0.1):
+    def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None):
         super(T5FeedForward, self).__init__()
         self.dim = dim
         self.dim_ffn = dim_ffn
 
+        if quantized:
+            if quant_scheme == "int8":
+                linear_cls = QuantLinearInt8
+        else:
+            linear_cls = nn.Linear
+
         # layers
-        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
-        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
-        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
+        self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False), GELU())
+        self.fc1 = linear_cls(dim, dim_ffn, bias=False)
+        self.fc2 = linear_cls(dim_ffn, dim, bias=False)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
@@ -135,16 +148,7 @@ class T5FeedForward(nn.Module):
 class T5SelfAttention(nn.Module):
-    def __init__(
-        self,
-        dim,
-        dim_attn,
-        dim_ffn,
-        num_heads,
-        num_buckets,
-        shared_pos=True,
-        dropout=0.1,
-    ):
+    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None):
         super(T5SelfAttention, self).__init__()
         self.dim = dim
         self.dim_attn = dim_attn
@@ -155,9 +159,9 @@ class T5SelfAttention(nn.Module):
         # layers
         self.norm1 = T5LayerNorm(dim)
-        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
+        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme)
         self.norm2 = T5LayerNorm(dim)
-        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
+        self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme)
         self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
 
     def forward(self, x, mask=None, pos_bias=None):
@@ -244,20 +248,9 @@ class T5RelativeEmbedding(nn.Module):
 class T5Encoder(nn.Module):
-    def __init__(
-        self,
-        vocab,
-        dim,
-        dim_attn,
-        dim_ffn,
-        num_heads,
-        num_layers,
-        num_buckets,
-        shared_pos=True,
-        dropout=0.1,
-        cpu_offload=False,
-    ):
+    def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
         super(T5Encoder, self).__init__()
         self.cpu_offload = cpu_offload
         self.dim = dim
         self.dim_attn = dim_attn
@@ -266,16 +259,17 @@ class T5Encoder(nn.Module):
         self.num_layers = num_layers
         self.num_buckets = num_buckets
         self.shared_pos = shared_pos
+        self.quant_scheme = quant_scheme
 
         # layers
         self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
         self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
         self.dropout = nn.Dropout(dropout)
-        self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
+        self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme) for _ in range(num_layers)])
         self.norm = T5LayerNorm(dim)
 
         # initialize weights
-        self.apply(init_weights)
+        # self.apply(init_weights)
 
     def forward(self, ids, mask=None):
         if self.cpu_offload:
@@ -301,7 +295,7 @@ class T5Encoder(nn.Module):
         if self.cpu_offload:
             self.norm = self.norm.cpu()
         x = self.dropout(x)
-        return x
+        return x.to(torch.bfloat16)
 
 
 class T5Decoder(nn.Module):
@@ -480,10 +474,16 @@ class T5EncoderModel:
         shard_fn=None,
         cpu_offload=False,
         offload_granularity="model",
+        t5_quantized=False,
+        t5_quantized_ckpt=None,
+        quant_scheme=None,
     ):
         self.text_len = text_len
         self.dtype = dtype
         self.device = device
-        self.checkpoint_path = checkpoint_path
+        if t5_quantized_ckpt is not None and t5_quantized:
+            self.checkpoint_path = t5_quantized_ckpt
+        else:
+            self.checkpoint_path = checkpoint_path
         self.tokenizer_path = tokenizer_path
         self.offload_granularity = offload_granularity
@@ -493,20 +493,23 @@ class T5EncoderModel:
         if self.cpu_offload:
             assert self.offload_granularity in ["block", "model"]
 
+        # init model
         model = (
             umt5_xxl(
                 encoder_only=True,
                 return_tokenizer=False,
                 dtype=dtype,
                 device=device,
-                cpu_offload=cpu_offload if self.offload_granularity == "block" else False,
+                cpu_offload=(cpu_offload if self.offload_granularity == "block" else False),
+                quantized=t5_quantized,
+                quant_scheme=quant_scheme,
             )
             .eval()
             .requires_grad_(False)
         )
-        logging.info(f"loading {checkpoint_path}")
-        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
+        logger.info(f"Loading weights from {self.checkpoint_path}")
+        model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu", weights_only=True))
         self.model = model
         if shard_fn is not None:
             self.model = shard_fn(self.model, sync_module_states=False)
...
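Editor's note: with the new arguments wired through, the T5 encoder can be pointed directly at an int8 checkpoint. A hedged usage sketch; paths and sizes are illustrative, not taken from this commit.

text_encoder = T5EncoderModel(
    text_len=512,
    dtype=torch.bfloat16,
    device=torch.device("cuda"),
    checkpoint_path="/path/to/models_t5_umt5-xxl-enc-bf16.pth",
    tokenizer_path="/path/to/umt5-xxl-tokenizer",
    t5_quantized=True,
    t5_quantized_ckpt="/path/to/umt5-xxl-enc-int8.pth",  # illustrative path
    quant_scheme="int8",
)

When t5_quantized is set, t5_quantized_ckpt replaces checkpoint_path and QuantLinearInt8 layers are instantiated in place of nn.Linear.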
@@ -9,10 +9,9 @@ import torch.nn.functional as F
 import torchvision.transforms as T
 
 from lightx2v.attentions import attention
-from lightx2v.models.input_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
 from loguru import logger
-
-from .xlm_roberta import XLMRoberta
+from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8
 
 __all__ = [
     "XLMRobertaCLIP",
@@ -48,7 +47,7 @@ class LayerNorm(nn.LayerNorm):
 class SelfAttention(nn.Module):
-    def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
+    def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -59,8 +58,14 @@ class SelfAttention(nn.Module):
         self.proj_dropout = proj_dropout
 
+        if quantized:
+            if quant_scheme == "int8":
+                linear_cls = QuantLinearInt8
+        else:
+            linear_cls = nn.Linear
+
         # layers
-        self.to_qkv = nn.Linear(dim, dim * 3)
-        self.proj = nn.Linear(dim, dim)
+        self.to_qkv = linear_cls(dim, dim * 3)
+        self.proj = linear_cls(dim, dim)
 
     def forward(self, x):
         """
@@ -86,7 +91,6 @@ class SwiGLU(nn.Module):
         super().__init__()
         self.dim = dim
         self.mid_dim = mid_dim
-
         # layers
         self.fc1 = nn.Linear(dim, mid_dim)
         self.fc2 = nn.Linear(dim, mid_dim)
@@ -99,7 +103,7 @@ class SwiGLU(nn.Module):
 class AttentionBlock(nn.Module):
-    def __init__(self, dim, mlp_ratio, num_heads, post_norm=False, causal=False, activation="quick_gelu", attn_dropout=0.0, proj_dropout=0.0, norm_eps=1e-5):
+    def __init__(self, dim, mlp_ratio, num_heads, post_norm=False, causal=False, activation="quick_gelu", attn_dropout=0.0, proj_dropout=0.0, norm_eps=1e-5, quantized=False, quant_scheme=None):
         assert activation in ["quick_gelu", "gelu", "swi_glu"]
         super().__init__()
         self.dim = dim
@@ -110,13 +114,19 @@ class AttentionBlock(nn.Module):
         self.norm_eps = norm_eps
 
+        if quantized:
+            if quant_scheme == "int8":
+                linear_cls = QuantLinearInt8
+        else:
+            linear_cls = nn.Linear
+
         # layers
         self.norm1 = LayerNorm(dim, eps=norm_eps)
-        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
+        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme)
         self.norm2 = LayerNorm(dim, eps=norm_eps)
         if activation == "swi_glu":
             self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
         else:
-            self.mlp = nn.Sequential(nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+            self.mlp = nn.Sequential(linear_cls(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), linear_cls(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
 
     def forward(self, x):
         if self.post_norm:
@@ -189,6 +199,8 @@ class VisionTransformer(nn.Module):
         proj_dropout=0.0,
         embedding_dropout=0.0,
         norm_eps=1e-5,
+        quantized=False,
+        quant_scheme=None,
     ):
         if image_size % patch_size != 0:
             logger.info("[WARNING] image_size is not divisible by patch_size", flush=True)
@@ -217,7 +229,9 @@ class VisionTransformer(nn.Module):
         # transformer
         self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
-        self.transformer = nn.Sequential(*[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps) for _ in range(num_layers)])
+        self.transformer = nn.Sequential(
+            *[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme) for _ in range(num_layers)]
+        )
         self.post_norm = LayerNorm(dim, eps=norm_eps)
 
         # head
@@ -252,28 +266,6 @@ class VisionTransformer(nn.Module):
         return x
 
 
-class XLMRobertaWithHead(XLMRoberta):
-    def __init__(self, **kwargs):
-        self.out_dim = kwargs.pop("out_dim")
-        super().__init__(**kwargs)
-
-        # head
-        mid_dim = (self.dim + self.out_dim) // 2
-        self.head = nn.Sequential(nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False))
-
-    def forward(self, ids):
-        # xlm-roberta
-        x = super().forward(ids)
-
-        # average pooling
-        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
-        x = (x * mask).sum(dim=1) / mask.sum(dim=1)
-
-        # head
-        x = self.head(x)
-        return x
-
-
 class XLMRobertaCLIP(nn.Module):
     def __init__(
         self,
@@ -292,15 +284,12 @@ class XLMRobertaCLIP(nn.Module):
         max_text_len=514,
         type_size=1,
         pad_id=1,
-        text_dim=1024,
-        text_heads=16,
-        text_layers=24,
-        text_post_norm=True,
-        text_dropout=0.1,
         attn_dropout=0.0,
         proj_dropout=0.0,
         embedding_dropout=0.0,
         norm_eps=1e-5,
+        quantized=False,
+        quant_scheme=None,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -317,10 +306,6 @@ class XLMRobertaCLIP(nn.Module):
         self.max_text_len = max_text_len
         self.type_size = type_size
         self.pad_id = pad_id
-        self.text_dim = text_dim
-        self.text_heads = text_heads
-        self.text_layers = text_layers
-        self.text_post_norm = text_post_norm
         self.norm_eps = norm_eps
 
         # models
@@ -340,40 +325,11 @@ class XLMRobertaCLIP(nn.Module):
             proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
             norm_eps=norm_eps,
+            quantized=quantized,
+            quant_scheme=quant_scheme,
         )
-        self.textual = XLMRobertaWithHead(
-            vocab_size=vocab_size,
-            max_seq_len=max_text_len,
-            type_size=type_size,
-            pad_id=pad_id,
-            dim=text_dim,
-            out_dim=embed_dim,
-            num_heads=text_heads,
-            num_layers=text_layers,
-            post_norm=text_post_norm,
-            dropout=text_dropout,
-        )
         self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
 
-    def forward(self, imgs, txt_ids):
-        """
-        imgs: [B, 3, H, W] of torch.float32.
-            - mean: [0.48145466, 0.4578275, 0.40821073]
-            - std: [0.26862954, 0.26130258, 0.27577711]
-        txt_ids: [B, L] of torch.long.
-            Encoded by data.CLIPTokenizer.
-        """
-        xi = self.visual(imgs)
-        xt = self.textual(txt_ids)
-        return xi, xt
-
-    def param_groups(self):
-        groups = [
-            {"params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")], "weight_decay": 0.0},
-            {"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
-        ]
-        return groups
-
 
 def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
     # init a model on device
@@ -414,11 +370,6 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
         max_text_len=514,
         type_size=1,
         pad_id=1,
-        text_dim=1024,
-        text_heads=16,
-        text_layers=24,
-        text_post_norm=True,
-        text_dropout=0.1,
         attn_dropout=0.0,
         proj_dropout=0.0,
         embedding_dropout=0.0,
@@ -428,20 +379,29 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme):
         self.dtype = dtype
         self.device = device
-        self.checkpoint_path = checkpoint_path
-        self.tokenizer_path = tokenizer_path
+        self.quantized = clip_quantized
+        if self.quantized:
+            self.checkpoint_path = clip_quantized_ckpt
+        else:
+            self.checkpoint_path = checkpoint_path
+
+        logger.info(f"Loading weights from {self.checkpoint_path}")
 
         # init model
-        self.model, self.transforms = clip_xlm_roberta_vit_h_14(pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device)
+        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
+            pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
+        )
         self.model = self.model.eval().requires_grad_(False)
-        logging.info(f"loading {checkpoint_path}")
-        self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
 
-        # init tokenizer
-        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace")
+        weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
+        keys = list(weight_dict.keys())
+        for key in keys:
+            if "textual" in key:
+                weight_dict.pop(key)
+        self.model.load_state_dict(weight_dict)
 
     def visual(self, videos, args):
         if args.cpu_offload:
...
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ["XLMRoberta", "xlm_roberta_large"]
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
# compute attention
p = self.dropout.p if self.training else 0.0
x = F.scaled_dot_product_attention(q, k, v, mask, p)
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
# output
x = self.o(x)
x = self.dropout(x)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.post_norm = post_norm
self.eps = eps
# layers
self.attn = SelfAttention(dim, num_heads, dropout, eps)
self.norm1 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
self.norm2 = nn.LayerNorm(dim, eps=eps)
def forward(self, x, mask):
if self.post_norm:
x = self.norm1(x + self.attn(x, mask))
x = self.norm2(x + self.ffn(x))
else:
x = x + self.attn(self.norm1(x), mask)
x = x + self.ffn(self.norm2(x))
return x
class XLMRoberta(nn.Module):
"""
XLMRobertaModel with no pooler and no LM head.
"""
def __init__(self, vocab_size=250002, max_seq_len=514, type_size=1, pad_id=1, dim=1024, num_heads=16, num_layers=24, post_norm=True, dropout=0.1, eps=1e-5):
super().__init__()
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.type_size = type_size
self.pad_id = pad_id
self.dim = dim
self.num_heads = num_heads
self.num_layers = num_layers
self.post_norm = post_norm
self.eps = eps
# embeddings
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
self.type_embedding = nn.Embedding(type_size, dim)
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
self.dropout = nn.Dropout(dropout)
# blocks
self.blocks = nn.ModuleList([AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)])
# norm layer
self.norm = nn.LayerNorm(dim, eps=eps)
def forward(self, ids):
"""
ids: [B, L] of torch.LongTensor.
"""
b, s = ids.shape
mask = ids.ne(self.pad_id).long()
# embeddings
x = self.token_embedding(ids) + self.type_embedding(torch.zeros_like(ids)) + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
if self.post_norm:
x = self.norm(x)
x = self.dropout(x)
# blocks
mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
for block in self.blocks:
x = block(x, mask)
# output
if not self.post_norm:
x = self.norm(x)
return x
def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
"""
XLMRobertaLarge adapted from Huggingface.
"""
# params
cfg = dict(vocab_size=250002, max_seq_len=514, type_size=1, pad_id=1, dim=1024, num_heads=16, num_layers=24, post_norm=True, dropout=0.1, eps=1e-5)
cfg.update(**kwargs)
# init a model on device
with torch.device(device):
model = XLMRoberta(**cfg)
return model
 import os
-import sys
 import torch
+import json
 from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
 from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
 from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
@@ -13,6 +13,7 @@ import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *
 from loguru import logger
+from safetensors import safe_open
 
 
 class HunyuanModel:
@@ -25,13 +26,15 @@ class HunyuanModel:
         self.config = config
         self.device = device
         self.args = args
+        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
+        self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
+        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
+        if self.dit_quantized:
+            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
+
         self._init_infer_class()
         self._init_weights()
-        if GET_RUNNING_FLAG() == "save_naive_quant":
-            assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
-            self.save_weights(self.config.quant_model_path)
-            sys.exit(0)
         self._init_infer()
 
         if config["parallel_attn_type"]:
@@ -45,18 +48,6 @@ class HunyuanModel:
         if self.config["cpu_offload"]:
             self.to_cpu()
 
-    def _init_infer_class(self):
-        self.pre_infer_class = HunyuanPreInfer
-        self.post_infer_class = HunyuanPostInfer
-        if self.config["feature_caching"] == "NoCaching":
-            self.transformer_infer_class = HunyuanTransformerInfer
-        elif self.config["feature_caching"] == "TaylorSeer":
-            self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
-        elif self.config["feature_caching"] == "Tea":
-            self.transformer_infer_class = HunyuanTransformerInferTeaCaching
-        else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
-
     def _load_ckpt(self):
         if self.args.task == "t2v":
             ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
@@ -65,18 +56,41 @@ class HunyuanModel:
         weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
         return weight_dict
 
-    def _load_ckpt_quant_model(self):
-        assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
-        logger.info(f"Loading quant model from {self.config.quant_model_path}")
-        quant_weights_path = os.path.join(self.config.quant_model_path, "quant_weights.pth")
-        weight_dict = torch.load(quant_weights_path, map_location=self.device, weights_only=True)
+    def _load_quant_ckpt(self):
+        ckpt_path = self.config.dit_quantized_ckpt
+        logger.info(f"Loading quant dit model from {ckpt_path}")
+
+        if ckpt_path.endswith(".pth"):
+            logger.info(f"Loading {ckpt_path} as PyTorch model.")
+            weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
+        else:
+            index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
+            if not index_files:
+                raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")
+
+            index_path = os.path.join(ckpt_path, index_files[0])
+            logger.info(f" Using safetensors index: {index_path}")
+
+            with open(index_path, "r") as f:
+                index_data = json.load(f)
+
+            weight_dict = {}
+            for filename in set(index_data["weight_map"].values()):
+                safetensor_path = os.path.join(ckpt_path, filename)
+                with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
+                    logger.info(f"Loading weights from {safetensor_path}")
+                    for k in f.keys():
+                        weight_dict[k] = f.get_tensor(k)
+                        if weight_dict[k].dtype == torch.float:
+                            weight_dict[k] = weight_dict[k].to(torch.bfloat16)
+
         return weight_dict
 
     def _init_weights(self):
-        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False) or self.config["mm_config"].get("mm_type", "Default") == "Default":
+        if not self.dit_quantized or self.weight_auto_quant:
            weight_dict = self._load_ckpt()
         else:
-            weight_dict = self._load_ckpt_quant_model()
+            weight_dict = self._load_quant_ckpt()
 
         # init weights
         self.pre_weight = self.pre_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
@@ -146,3 +160,15 @@ class HunyuanModel:
             self.scheduler.cnt += 1
             if self.scheduler.cnt == self.scheduler.num_steps:
                 self.scheduler.cnt = 0
+
+    def _init_infer_class(self):
+        self.pre_infer_class = HunyuanPreInfer
+        self.post_infer_class = HunyuanPostInfer
+        if self.config["feature_caching"] == "NoCaching":
+            self.transformer_infer_class = HunyuanTransformerInfer
+        elif self.config["feature_caching"] == "TaylorSeer":
+            self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
+        elif self.config["feature_caching"] == "Tea":
+            self.transformer_infer_class = HunyuanTransformerInferTeaCaching
+        else:
+            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
 import torch
 from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
-from lightx2v.common.offload.manager import WeightAsyncStreamManager
+from lightx2v.common.offload.manager import WeightAsyncStreamManager, LazyWeightAsyncStreamManager
 from lightx2v.utils.envs import *
@@ -24,8 +24,25 @@ class WanTransformerInfer:
             if offload_granularity == "block":
                 self.infer_func = self._infer_with_offload
             elif offload_granularity == "phase":
-                self.infer_func = self._infer_with_phases_offload
-                self.weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.blocks_num, offload_ratio=offload_ratio, phases_num=self.phases_num)
+                if not self.config.get("lazy_load", False):
+                    self.infer_func = self._infer_with_phases_offload
+                else:
+                    self.infer_func = self._infer_with_phases_lazy_offload
+
+                if not self.config.get("lazy_load", False):
+                    self.weights_stream_mgr = WeightAsyncStreamManager(
+                        blocks_num=self.blocks_num,
+                        offload_ratio=offload_ratio,
+                        phases_num=self.phases_num,
+                    )
+                else:
+                    self.weights_stream_mgr = LazyWeightAsyncStreamManager(
+                        blocks_num=self.blocks_num,
+                        offload_ratio=offload_ratio,
+                        phases_num=self.phases_num,
+                        num_disk_workers=self.config.get("num_disk_workers", 2),
+                        max_memory=self.config.get("max_memory", 2),
+                    )
             else:
                 self.infer_func = self._infer_without_offload
@@ -33,10 +50,7 @@ class WanTransformerInfer:
         self.scheduler = scheduler
 
     def _calculate_q_k_len(self, q, k_lens):
-        # Handle query and key lengths (use `q_lens` and `k_lens` or set them to Lq and Lk if None)
         q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
-
-        # We don't have a batch dimension anymore, so directly use the `q_lens` and `k_lens` values
         cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
         cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
         return cu_seqlens_q, cu_seqlens_k
@@ -45,6 +59,7 @@ class WanTransformerInfer:
     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
 
+    # bug
     def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         for block_idx in range(self.blocks_num):
             if block_idx == 0:
@@ -107,15 +122,77 @@ class WanTransformerInfer:
                 elif cur_phase_idx == 2:
                     x = self._infer_ffn(cur_phase, x, c_shift_msa, c_scale_msa, c_gate_msa)
 
-                is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == 2
+                is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
                 if not is_last_phase:
-                    next_block_idx = block_idx + 1 if cur_phase_idx == 2 else block_idx
-                    next_phase_idx = (cur_phase_idx + 1) % 3
+                    next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
+                    next_phase_idx = (phase_idx + 1) % self.phases_num
                    self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
 
                 self.weights_stream_mgr.swap_phases()
 
             weights.blocks[block_idx].modulation.to_cpu()
             torch.cuda.empty_cache()
         return x
 
+    def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+        self.weights_stream_mgr.prefetch_weights_from_disk(weights)
+        self.weights_stream_mgr._async_prefetch_block(weights)
+
+        for block_idx in range(weights.blocks_num):
+            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
+                weights.blocks[block_idx].modulation.to_cuda()
+
+                if embed0.dim() == 3:
+                    modulation = weights.blocks[block_idx].modulation.tensor.unsqueeze(2)
+                    current_embed0 = (modulation + embed0).chunk(6, dim=1)
+                    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in current_embed0]
+                elif embed0.dim() == 2:
+                    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.blocks[block_idx].modulation.tensor + embed0).chunk(6, dim=1)
+
+            for phase_idx in range(self.weights_stream_mgr.phases_num):
+                if block_idx == 0 and phase_idx == 0:
+                    obj_key = (block_idx, phase_idx)
+                    phase = self.weights_stream_mgr.pin_memory_buffer.get(obj_key)
+                    phase.to_cuda()
+                    self.weights_stream_mgr.active_weights[0] = (obj_key, phase)
+
+                with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
+                    (
+                        (
+                            _,
+                            cur_phase_idx,
+                        ),
+                        cur_phase,
+                    ) = self.weights_stream_mgr.active_weights[0]
+
+                    if cur_phase_idx == 0:
+                        x = self._infer_self_attn(
+                            cur_phase,
+                            x,
+                            shift_msa,
+                            scale_msa,
+                            gate_msa,
+                            grid_sizes,
+                            freqs,
+                            seq_lens,
+                        )
+                    elif cur_phase_idx == 1:
+                        x = self._infer_cross_attn(cur_phase, x, context)
+                    elif cur_phase_idx == 2:
+                        x = self._infer_ffn(cur_phase, x, c_shift_msa, c_scale_msa, c_gate_msa)
+
+                if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
+                    next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
+                    next_phase_idx = (phase_idx + 1) % self.weights_stream_mgr.phases_num
+                    self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
+
+                self.weights_stream_mgr.swap_phases()
+
+            weights.blocks[block_idx].modulation.to_cpu()
+            self.weights_stream_mgr._async_prefetch_block(weights)
+            torch.cuda.empty_cache()
...
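Editor's note: the lazy path is driven entirely by configuration keys read in the constructor above. A hypothetical config fragment combining them; only the key names come from this diff, the surrounding format, values, and comments are assumptions.

lazy_offload_config = {
    "cpu_offload": True,
    "offload_granularity": "phase",  # lazy load only applies to the phase-level offload path
    "lazy_load": True,               # selects LazyWeightAsyncStreamManager and _infer_with_phases_lazy_offload
    "num_disk_workers": 2,           # presumed number of background workers streaming weights from disk
    "max_memory": 2,                 # presumed budget for the pin-memory prefetch buffer
    "dit_quantized_ckpt": "/path/to/split_quant_ckpt",  # directory holding the split safetensors files
}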
@@ -31,14 +31,16 @@ class WanModel:
     def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
+        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
+        self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
+        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
+        if self.dit_quantized:
+            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
+
         self.device = device
         self._init_infer_class()
         self._init_weights()
-        if GET_RUNNING_FLAG() == "save_naive_quant":
-            assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
-            self.save_weights(self.config.quant_model_path)
-            sys.exit(0)
         self._init_infer()
         self.current_lora = None
@@ -64,9 +66,9 @@ class WanModel:
         use_bfloat16 = self.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
             if use_bfloat16:
-                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).pin_memory().to(self.device) for key in f.keys()}
+                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
             else:
-                tensor_dict = {key: f.get_tensor(key).pin_memory().to(self.device) for key in f.keys()}
+                tensor_dict = {key: f.get_tensor(key).to(self.device) for key in f.keys()}
         return tensor_dict
 
     def _load_ckpt(self):
@@ -82,22 +84,19 @@ class WanModel:
         return weight_dict
 
     def _load_quant_ckpt(self):
-        assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
-        ckpt_path = self.config.quant_model_path
-        logger.info(f"Loading quant model from {ckpt_path}")
-
-        quant_pth_file = os.path.join(ckpt_path, "quant_weights.pth")
-
-        if os.path.exists(quant_pth_file):
-            logger.info("Found quant_weights.pth, loading as PyTorch model.")
-            weight_dict = torch.load(quant_pth_file, map_location=self.device, weights_only=True)
+        ckpt_path = self.config.dit_quantized_ckpt
+        logger.info(f"Loading quant dit model from {ckpt_path}")
+
+        if ckpt_path.endswith(".pth"):
+            logger.info(f"Loading {ckpt_path} as PyTorch model.")
+            weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
         else:
             index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
             if not index_files:
-                raise FileNotFoundError(f"No quant_weights.pth or *.index.json found in {ckpt_path}")
+                raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")
 
             index_path = os.path.join(ckpt_path, index_files[0])
-            logger.info(f"quant_weights.pth not found. Using safetensors index: {index_path}")
+            logger.info(f" Using safetensors index: {index_path}")
 
             with open(index_path, "r") as f:
                 index_data = json.load(f)
@@ -114,14 +113,48 @@ class WanModel:
         return weight_dict
 
+    def _load_quant_split_ckpt(self):
+        lazy_load_model_path = self.config.dit_quantized_ckpt
+        logger.info(f"Loading splited quant model from {lazy_load_model_path}")
+        pre_post_weight_dict, transformer_weight_dict = {}, {}
+
+        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
+        with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
+            for k in f.keys():
+                pre_post_weight_dict[k] = f.get_tensor(k)
+                if pre_post_weight_dict[k].dtype == torch.float:
+                    pre_post_weight_dict[k] = pre_post_weight_dict[k].to(torch.bfloat16)
+
+        safetensors_pattern = os.path.join(lazy_load_model_path, "block_*.safetensors")
+        safetensors_files = glob.glob(safetensors_pattern)
+        if not safetensors_files:
+            raise FileNotFoundError(f"No .safetensors files found in directory: {lazy_load_model_path}")
+
+        for file_path in safetensors_files:
+            with safe_open(file_path, framework="pt") as f:
+                for k in f.keys():
+                    if "modulation" in k:
+                        transformer_weight_dict[k] = f.get_tensor(k)
+                        if transformer_weight_dict[k].dtype == torch.float:
+                            transformer_weight_dict[k] = transformer_weight_dict[k].to(torch.bfloat16)
+
+        return pre_post_weight_dict, transformer_weight_dict
+
     def _init_weights(self, weight_dict=None):
         if weight_dict is None:
-            if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False) or self.config["mm_config"].get("mm_type", "Default") == "Default":
+            if not self.dit_quantized or self.weight_auto_quant:
                 self.original_weight_dict = self._load_ckpt()
             else:
-                self.original_weight_dict = self._load_quant_ckpt()
+                if not self.config.get("lazy_load", False):
+                    self.original_weight_dict = self._load_quant_ckpt()
+                else:
+                    (
+                        self.original_weight_dict,
+                        self.transformer_weight_dict,
+                    ) = self._load_quant_split_ckpt()
         else:
             self.original_weight_dict = weight_dict
 
         # init weights
         self.pre_weight = self.pre_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
@@ -129,6 +162,9 @@ class WanModel:
         # load weights
         self.pre_weight.load(self.original_weight_dict)
         self.post_weight.load(self.original_weight_dict)
-        self.transformer_weights.load(self.original_weight_dict)
+        if hasattr(self, "transformer_weight_dict"):
+            self.transformer_weights.load(self.transformer_weight_dict)
+        else:
+            self.transformer_weights.load(self.original_weight_dict)
 
     def _init_infer(self):
@@ -136,28 +172,6 @@ class WanModel:
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
 
-    def save_weights(self, save_path):
-        if not os.path.exists(save_path):
-            os.makedirs(save_path)
-
-        pre_state_dict = self.pre_weight.state_dict()
-        logger.info(pre_state_dict.keys())
-
-        post_state_dict = self.post_weight.state_dict()
-        logger.info(post_state_dict.keys())
-
-        transformer_state_dict = self.transformer_weights.state_dict()
-        logger.info(transformer_state_dict.keys())
-
-        save_dict = {}
-        save_dict.update(pre_state_dict)
-        save_dict.update(post_state_dict)
-        save_dict.update(transformer_state_dict)
-
-        save_path = os.path.join(save_path, "quant_weights.pth")
-        torch.save(save_dict, save_path)
-        logger.info(f"Save weights to {save_path}")
-
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
         self.pre_infer.set_scheduler(scheduler)
...
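Editor's note: for the lazy path, _load_quant_split_ckpt above and the per-block safe_open handles in the transformer weights (next diff) expect the quantized checkpoint to be pre-split on disk, roughly like this; the directory name is illustrative.

dit_quantized_ckpt/
├── non_block.safetensors   # pre-/post-process weights, loaded eagerly into pre_post_weight_dict
├── block_0.safetensors     # weights of blocks.0.*, opened lazily per block
├── block_1.safetensors
└── ...

Only the modulation tensors are read eagerly from the block files; everything else is fetched on demand through the lazy-load file handles.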
 import torch
+import os
 from lightx2v.utils.registry_factory import (
     MM_WEIGHT_REGISTER,
     LN_WEIGHT_REGISTER,
@@ -7,6 +8,7 @@ from lightx2v.utils.registry_factory import (
     ATTN_WEIGHT_REGISTER,
 )
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
+from safetensors import safe_open
 
 
 class WanTransformerWeights(WeightModule):
@@ -36,21 +38,28 @@ class WanTransformerAttentionBlock(WeightModule):
             "modulation",
             TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.modulation"),
         )
+        self.lazy_load = self.config.get("lazy_load", False)
+        if self.lazy_load:
+            lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors")
+            self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
+        else:
+            self.lazy_load_file = None
+
         self.compute_phases = WeightModuleList(
             [
-                WanSelfAttention(block_index, task, mm_type, config),
-                WanCrossAttention(block_index, task, mm_type, config),
-                WanFFN(block_index, task, mm_type, config),
+                WanSelfAttention(block_index, task, mm_type, config, self.lazy_load, self.lazy_load_file),
+                WanCrossAttention(block_index, task, mm_type, config, self.lazy_load, self.lazy_load_file),
+                WanFFN(block_index, task, mm_type, config, self.lazy_load, self.lazy_load_file),
             ]
         )
 
         self.add_module("compute_phases", self.compute_phases)
 
 
+# i2v
 class WanSelfAttention(WeightModule):
-    def __init__(self, block_index, task, mm_type, config):
+    def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -59,11 +68,16 @@ class WanSelfAttention(WeightModule):
         self.quant_method = config["mm_config"].get("quant_method", None)
         self.sparge = config.get("sparge", False)
+        self.lazy_load = lazy_load
+        self.lazy_load_file = lazy_load_file
 
         self.add_module(
             "self_attn_q",
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.self_attn.q.weight",
                 f"blocks.{self.block_index}.self_attn.q.bias",
+                self.lazy_load,
+                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -71,6 +85,8 @@ class WanSelfAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.self_attn.k.weight",
                 f"blocks.{self.block_index}.self_attn.k.bias",
+                self.lazy_load,
+                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -78,6 +94,8 @@ class WanSelfAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.self_attn.v.weight",
                 f"blocks.{self.block_index}.self_attn.v.bias",
+                self.lazy_load,
+                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -85,15 +103,25 @@ class WanSelfAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.self_attn.o.weight",
                 f"blocks.{self.block_index}.self_attn.o.bias",
+                self.lazy_load,
+                self.lazy_load_file,
             ),
         )
         self.add_module(
             "self_attn_norm_q",
-            RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_q.weight"),
+            RMS_WEIGHT_REGISTER["sgl-kernel"](
+                f"blocks.{self.block_index}.self_attn.norm_q.weight",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
         )
         self.add_module(
             "self_attn_norm_k",
-            RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_k.weight"),
+            RMS_WEIGHT_REGISTER["sgl-kernel"](
+                f"blocks.{self.block_index}.self_attn.norm_k.weight",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
         )
 
         if self.sparge:
             assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
@@ -108,27 +136,39 @@ class WanSelfAttention(WeightModule):
         if self.quant_method in ["smoothquant", "awq"]:
             self.register_parameter(
                 "smooth_norm1_weight",
-                TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm1.weight"),
+                TENSOR_REGISTER["Default"](
+                    f"blocks.{self.block_index}.affine_norm1.weight",
+                    self.lazy_load,
+                    self.lazy_load_file,
+                ),
             )
             self.register_parameter(
                 "smooth_norm1_bias",
-                TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm1.bias"),
+                TENSOR_REGISTER["Default"](
+                    f"blocks.{self.block_index}.affine_norm1.bias",
+                    self.lazy_load,
+                    self.lazy_load_file,
+                ),
             )
 
 
 class WanCrossAttention(WeightModule):
-    def __init__(self, block_index, task, mm_type, config):
+    def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
         self.task = task
         self.config = config
+        self.lazy_load = lazy_load
+        self.lazy_load_file = lazy_load_file
 
         self.add_module(
             "norm3",
             LN_WEIGHT_REGISTER["Default"](
                 f"blocks.{self.block_index}.norm3.weight",
                 f"blocks.{self.block_index}.norm3.bias",
+                self.lazy_load,
+                self.lazy_load_file,
                 eps=1e-6,
             ),
         )
@@ -137,6 +177,8 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.cross_attn.q.weight",
                 f"blocks.{self.block_index}.cross_attn.q.bias",
+                self.lazy_load,
+                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -144,6 +186,8 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.cross_attn.k.weight",
                 f"blocks.{self.block_index}.cross_attn.k.bias",
+                self.lazy_load,
+                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -151,6 +195,8 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.cross_attn.v.weight",
                 f"blocks.{self.block_index}.cross_attn.v.bias",
+                self.lazy_load,
+                self.lazy_load_file,
            ),
         )
         self.add_module(
@@ -158,15 +204,25 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.cross_attn.o.weight",
                 f"blocks.{self.block_index}.cross_attn.o.bias",
+                self.lazy_load,
+                self.lazy_load_file,
             ),
         )
         self.add_module(
             "cross_attn_norm_q",
-            RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_q.weight"),
+            RMS_WEIGHT_REGISTER["sgl-kernel"](
+                f"blocks.{self.block_index}.cross_attn.norm_q.weight",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
         )
         self.add_module(
             "cross_attn_norm_k",
-            RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k.weight"),
+            RMS_WEIGHT_REGISTER["sgl-kernel"](
+                f"blocks.{self.block_index}.cross_attn.norm_k.weight",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
         )
         self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
@@ -176,6 +232,8 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.cross_attn.k_img.weight",
                 f"blocks.{self.block_index}.cross_attn.k_img.bias",
+                self.lazy_load,
+                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -183,29 +241,39 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.cross_attn.v_img.weight",
                 f"blocks.{self.block_index}.cross_attn.v_img.bias",
+                self.lazy_load,
+                self.lazy_load_file,
             ),
         )
         self.add_module(
             "cross_attn_norm_k_img",
-            RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight"),
+            RMS_WEIGHT_REGISTER["sgl-kernel"](
+                f"blocks.{self.block_index}.cross_attn.norm_k_img.weight",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
         )
         self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
 
 
 class WanFFN(WeightModule):
-    def __init__(self, block_index, task, mm_type, config):
+    def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
         self.task = task
         self.config = config
         self.quant_method = config["mm_config"].get("quant_method", None)
+        self.lazy_load = lazy_load
+        self.lazy_load_file = lazy_load_file
 
         self.add_module(
             "ffn_0",
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.ffn.0.weight",
                 f"blocks.{self.block_index}.ffn.0.bias",
+                self.lazy_load,
+                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -213,15 +281,25 @@ class WanFFN(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"blocks.{self.block_index}.ffn.2.weight",
                 f"blocks.{self.block_index}.ffn.2.bias",
+                self.lazy_load,
+                self.lazy_load_file,
             ),
         )
 
         if self.quant_method in ["smoothquant", "awq"]:
             self.register_parameter(
                 "smooth_norm2_weight",
-                TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm2.weight"),
+                TENSOR_REGISTER["Default"](
+                    f"blocks.{self.block_index}.affine_norm3.weight",
+                    self.lazy_load,
+                    self.lazy_load_file,
+                ),
             )
             self.register_parameter(
                 "smooth_norm2_bias",
-                TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.affine_norm2.bias"),
+                TENSOR_REGISTER["Default"](
+                    f"blocks.{self.block_index}.affine_norm3.bias",
+                    self.lazy_load,
+                    self.lazy_load_file,
+                ),
            )
...@@ -26,6 +26,7 @@ class DefaultRunner: ...@@ -26,6 +26,7 @@ class DefaultRunner:
logger.warning("No prompt enhancer server available, disable prompt enhancer.") logger.warning("No prompt enhancer server available, disable prompt enhancer.")
def init_modules(self): def init_modules(self):
self.set_init_device()
if self.config["mode"] == "split_server": if self.config["mode"] == "split_server":
self.tensor_transporter = TensorTransporter() self.tensor_transporter = TensorTransporter()
self.image_transporter = ImageTransporter() self.image_transporter = ImageTransporter()
...@@ -45,6 +46,7 @@ class DefaultRunner: ...@@ -45,6 +46,7 @@ class DefaultRunner:
else: else:
self.run_input_encoder = self.run_input_encoder_server_t2v self.run_input_encoder = self.run_input_encoder_server_t2v
else: else:
if not self.config.get("lazy_load", False):
self.load_model() self.load_model()
self.run_dit = self.run_dit_local self.run_dit = self.run_dit_local
self.run_vae_decoder = self.run_vae_decoder_local self.run_vae_decoder = self.run_vae_decoder_local
...@@ -53,23 +55,21 @@ class DefaultRunner: ...@@ -53,23 +55,21 @@ class DefaultRunner:
else: else:
self.run_input_encoder = self.run_input_encoder_local_t2v self.run_input_encoder = self.run_input_encoder_local_t2v
def get_init_device(self): def set_init_device(self):
if self.config["parallel_attn_type"]: if self.config["parallel_attn_type"]:
cur_rank = dist.get_rank() cur_rank = dist.get_rank()
torch.cuda.set_device(cur_rank) torch.cuda.set_device(cur_rank)
if self.config.cpu_offload: if self.config.cpu_offload:
init_device = torch.device("cpu") self.init_device = torch.device("cpu")
else: else:
init_device = torch.device("cuda") self.init_device = torch.device("cuda")
return init_device
@ProfilingContext("Load models") @ProfilingContext("Load models")
def load_model(self): def load_model(self):
init_device = self.get_init_device() self.model = self.load_transformer()
self.text_encoders = self.load_text_encoder(init_device) self.text_encoders = self.load_text_encoder()
self.model = self.load_transformer(init_device) self.image_encoder = self.load_image_encoder()
self.image_encoder = self.load_image_encoder(init_device) self.vae_encoder, self.vae_decoder = self.load_vae()
self.vae_encoder, self.vae_decoder = self.load_vae(init_device)
def check_sub_servers(self, task_type): def check_sub_servers(self, task_type):
urls = self.config.get("sub_servers", {}).get(task_type, []) urls = self.config.get("sub_servers", {}).get(task_type, [])
...@@ -124,7 +124,10 @@ class DefaultRunner: ...@@ -124,7 +124,10 @@ class DefaultRunner:
def end_run(self): def end_run(self):
self.model.scheduler.clear() self.model.scheduler.clear()
del self.inputs, self.model.scheduler del self.inputs, self.model.scheduler
if self.config.get("lazy_load", False):
del self.model
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect()
@ProfilingContext("Run Encoders") @ProfilingContext("Run Encoders")
async def run_input_encoder_local_i2v(self): async def run_input_encoder_local_i2v(self):
...@@ -133,16 +136,22 @@ class DefaultRunner: ...@@ -133,16 +136,22 @@ class DefaultRunner:
clip_encoder_out = self.run_image_encoder(img) clip_encoder_out = self.run_image_encoder(img)
vae_encode_out, kwargs = self.run_vae_encoder(img) vae_encode_out, kwargs = self.run_vae_encoder(img)
text_encoder_output = self.run_text_encoder(prompt, img) text_encoder_output = self.run_text_encoder(prompt, img)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img) return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
@ProfilingContext("Run Encoders") @ProfilingContext("Run Encoders")
async def run_input_encoder_local_t2v(self): async def run_input_encoder_local_t2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt, None) text_encoder_output = self.run_text_encoder(prompt, None)
torch.cuda.empty_cache()
gc.collect()
return {"text_encoder_output": text_encoder_output, "image_encoder_output": None} return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
@ProfilingContext("Run DiT") @ProfilingContext("Run DiT")
async def run_dit_local(self, kwargs): async def run_dit_local(self, kwargs):
if self.config.get("lazy_load", False):
self.model = self.load_transformer()
self.init_scheduler() self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"]) self.model.scheduler.prepare(self.inputs["image_encoder_output"])
latents, generator = self.run() latents, generator = self.run()
...@@ -151,7 +160,12 @@ class DefaultRunner: ...@@ -151,7 +160,12 @@ class DefaultRunner:
@ProfilingContext("Run VAE Decoder") @ProfilingContext("Run VAE Decoder")
async def run_vae_decoder_local(self, latents, generator): async def run_vae_decoder_local(self, latents, generator):
if self.config.get("lazy_load", False):
self.vae_decoder = self.load_vae_decoder()
images = self.vae_decoder.decode(latents, generator=generator, config=self.config) images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
if self.config.get("lazy_load", False):
torch.cuda.empty_cache()
gc.collect()
return images return images
@ProfilingContext("Save video") @ProfilingContext("Save video")
...@@ -228,12 +242,16 @@ class DefaultRunner: ...@@ -228,12 +242,16 @@ class DefaultRunner:
n_prompt = self.config.get("negative_prompt", "") n_prompt = self.config.get("negative_prompt", "")
img = Image.open(self.config["image_path"]).convert("RGB") img = Image.open(self.config["image_path"]).convert("RGB")
clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders_i2v(prompt, img, n_prompt) clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders_i2v(prompt, img, n_prompt)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img) return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
async def run_input_encoder_server_t2v(self): async def run_input_encoder_server_t2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
n_prompt = self.config.get("negative_prompt", "") n_prompt = self.config.get("negative_prompt", "")
text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt) text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
torch.cuda.empty_cache()
gc.collect()
return {"text_encoder_output": text_encoder_output, "image_encoder_output": None} return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
async def run_dit_server(self, kwargs): async def run_dit_server(self, kwargs):
...@@ -265,5 +283,5 @@ class DefaultRunner: ...@@ -265,5 +283,5 @@ class DefaultRunner:
images = await self.run_vae_decoder(latents, generator) images = await self.run_vae_decoder(latents, generator)
self.save_video(images) self.save_video(images)
del latents, generator, images del latents, generator, images
gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect()
...@@ -21,23 +21,23 @@ class HunyuanRunner(DefaultRunner): ...@@ -21,23 +21,23 @@ class HunyuanRunner(DefaultRunner):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
def load_transformer(self, init_device): def load_transformer(self):
return HunyuanModel(self.config.model_path, self.config, init_device, self.config) return HunyuanModel(self.config.model_path, self.config, self.init_device, self.config)
def load_image_encoder(self, init_device): def load_image_encoder(self):
return None return None
def load_text_encoder(self, init_device): def load_text_encoder(self):
if self.config.task == "t2v": if self.config.task == "t2v":
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(self.config.model_path, "text_encoder"), init_device) text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(self.config.model_path, "text_encoder"), self.init_device)
else: else:
text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(self.config.model_path, "text_encoder_i2v"), init_device) text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(self.config.model_path, "text_encoder_i2v"), self.init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(self.config.model_path, "text_encoder_2"), init_device) text_encoder_2 = TextEncoderHFClipModel(os.path.join(self.config.model_path, "text_encoder_2"), self.init_device)
text_encoders = [text_encoder_1, text_encoder_2] text_encoders = [text_encoder_1, text_encoder_2]
return text_encoders return text_encoders
def load_vae(self, init_device): def load_vae(self):
vae_model = VideoEncoderKLCausal3DModel(self.config.model_path, dtype=torch.float16, device=init_device, config=self.config) vae_model = VideoEncoderKLCausal3DModel(self.config.model_path, dtype=torch.float16, device=self.init_device, config=self.config)
return vae_model, vae_model return vae_model, vae_model
def init_scheduler(self): def init_scheduler(self):
......
import os import os
import gc
import numpy as np import numpy as np
import torch import torch
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
...@@ -25,63 +26,83 @@ class WanRunner(DefaultRunner): ...@@ -25,63 +26,83 @@ class WanRunner(DefaultRunner):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
def load_transformer(self, init_device): def load_transformer(self):
model = WanModel(self.config.model_path, self.config, init_device) model = WanModel(
self.config.model_path,
self.config,
self.init_device,
)
if self.config.lora_path: if self.config.lora_path:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
lora_wrapper = WanLoraWrapper(model) lora_wrapper = WanLoraWrapper(model)
lora_name = lora_wrapper.load_lora(self.config.lora_path) lora_name = lora_wrapper.load_lora(self.config.lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model) lora_wrapper.apply_lora(lora_name, self.config.strength_model)
logger.info(f"Loaded LoRA: {lora_name}") logger.info(f"Loaded LoRA: {lora_name}")
return model return model
def load_image_encoder(self, init_device): def load_image_encoder(self):
image_encoder = None image_encoder = None
if self.config.task == "i2v": if self.config.task == "i2v":
image_encoder = CLIPModel( image_encoder = CLIPModel(
dtype=torch.float16, dtype=torch.float16,
device=init_device, device=self.init_device,
checkpoint_path=os.path.join( checkpoint_path=os.path.join(
self.config.model_path, self.config.model_path,
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
), ),
tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"), clip_quantized=self.config.get("clip_quantized", False),
clip_quantized_ckpt=self.config.get("clip_quantized_ckpt", None),
quant_scheme=self.config.get("clip_quant_scheme", None),
) )
return image_encoder return image_encoder
def load_text_encoder(self, init_device): def load_text_encoder(self):
text_encoder = T5EncoderModel( text_encoder = T5EncoderModel(
text_len=self.config["text_len"], text_len=self.config["text_len"],
dtype=torch.bfloat16, dtype=torch.bfloat16,
device=init_device, device=self.init_device,
checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"), checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"), tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
shard_fn=None, shard_fn=None,
cpu_offload=self.config.cpu_offload, cpu_offload=self.config.cpu_offload,
offload_granularity=self.config.get("text_encoder_offload_granularity", "model"), offload_granularity=self.config.get("t5_offload_granularity", "model"),
t5_quantized=self.config.get("t5_quantized", False),
t5_quantized_ckpt=self.config.get("t5_quantized_ckpt", None),
quant_scheme=self.config.get("t5_quant_scheme", None),
) )
text_encoders = [text_encoder] text_encoders = [text_encoder]
return text_encoders return text_encoders
def load_vae(self, init_device): def load_vae_encoder(self):
vae_config = {
"vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
"device": self.init_device,
"parallel": self.config.parallel_vae,
"use_tiling": self.config.get("use_tiling_vae", False),
}
if self.config.task != "i2v":
return None
else:
return WanVAE(**vae_config)
def load_vae_decoder(self):
vae_config = { vae_config = {
"vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), "vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
"device": init_device, "device": self.init_device,
"parallel": self.config.parallel_vae, "parallel": self.config.parallel_vae,
"use_tiling": self.config.get("use_tiling_vae", False), "use_tiling": self.config.get("use_tiling_vae", False),
} }
use_tiny_decoder = self.config.get("tiny_vae", False) if self.config.get("tiny_vae", False):
is_i2v = self.config.task == "i2v"
if use_tiny_decoder:
vae_decoder = WanVAE_tiny( vae_decoder = WanVAE_tiny(
vae_pth=self.config.tiny_vae_path, vae_pth=self.config.tiny_vae_path,
device=init_device, device=self.init_device,
).to("cuda") ).to("cuda")
vae_encoder = WanVAE(**vae_config) if is_i2v else None
else: else:
vae_decoder = WanVAE(**vae_config) vae_decoder = WanVAE(**vae_config)
vae_encoder = vae_decoder if is_i2v else None return vae_decoder
return vae_encoder, vae_decoder def load_vae(self):
return self.load_vae_encoder(), self.load_vae_decoder()
def init_scheduler(self): def init_scheduler(self):
if self.config.feature_caching == "NoCaching": if self.config.feature_caching == "NoCaching":
...@@ -93,17 +114,29 @@ class WanRunner(DefaultRunner): ...@@ -93,17 +114,29 @@ class WanRunner(DefaultRunner):
self.model.set_scheduler(scheduler) self.model.set_scheduler(scheduler)
def run_text_encoder(self, text, img): def run_text_encoder(self, text, img):
if self.config.get("lazy_load", False):
self.text_encoders = self.load_text_encoder()
text_encoder_output = {} text_encoder_output = {}
n_prompt = self.config.get("negative_prompt", "") n_prompt = self.config.get("negative_prompt", "")
context = self.text_encoders[0].infer([text]) context = self.text_encoders[0].infer([text])
context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""]) context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
if self.config.get("lazy_load", False):
del self.text_encoders[0]
torch.cuda.empty_cache()
gc.collect()
text_encoder_output["context"] = context text_encoder_output["context"] = context
text_encoder_output["context_null"] = context_null text_encoder_output["context_null"] = context_null
return text_encoder_output return text_encoder_output
def run_image_encoder(self, img): def run_image_encoder(self, img):
if self.config.get("lazy_load", False):
self.image_encoder = self.load_image_encoder()
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda() img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16) clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
if self.config.get("lazy_load", False):
del self.image_encoder
torch.cuda.empty_cache()
gc.collect()
return clip_encoder_out return clip_encoder_out
def run_vae_encoder(self, img): def run_vae_encoder(self, img):
...@@ -120,11 +153,19 @@ class WanRunner(DefaultRunner): ...@@ -120,11 +153,19 @@ class WanRunner(DefaultRunner):
self.config.lat_h, kwargs["lat_h"] = lat_h, lat_h self.config.lat_h, kwargs["lat_h"] = lat_h, lat_h
self.config.lat_w, kwargs["lat_w"] = lat_w, lat_w self.config.lat_w, kwargs["lat_w"] = lat_w, lat_w
msk = torch.ones(1, self.config.target_video_length, lat_h, lat_w, device=torch.device("cuda")) msk = torch.ones(
1,
self.config.target_video_length,
lat_h,
lat_w,
device=torch.device("cuda"),
)
msk[:, 1:] = 0 msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0] msk = msk.transpose(1, 2)[0]
if self.config.get("lazy_load", False):
self.vae_encoder = self.load_vae_encoder()
vae_encode_out = self.vae_encoder.encode( vae_encode_out = self.vae_encoder.encode(
[ [
torch.concat( torch.concat(
...@@ -137,12 +178,22 @@ class WanRunner(DefaultRunner): ...@@ -137,12 +178,22 @@ class WanRunner(DefaultRunner):
], ],
self.config, self.config,
)[0] )[0]
if self.config.get("lazy_load", False):
del self.vae_encoder
torch.cuda.empty_cache()
gc.collect()
vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16) vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
return vae_encode_out, kwargs return vae_encode_out, kwargs
def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img): def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
image_encoder_output = {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out} image_encoder_output = {
return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output} "clip_encoder_out": clip_encoder_out,
"vae_encode_out": vae_encode_out,
}
return {
"text_encoder_output": text_encoder_output,
"image_encoder_output": image_encoder_output,
}
def set_target_shape(self): def set_target_shape(self):
ret = {} ret = {}
...@@ -167,4 +218,11 @@ class WanRunner(DefaultRunner): ...@@ -167,4 +218,11 @@ class WanRunner(DefaultRunner):
return ret return ret
def save_video_func(self, images): def save_video_func(self, images):
cache_video(tensor=images, save_file=self.config.save_video_path, fps=self.config.get("fps", 16), nrow=1, normalize=True, value_range=(-1, 1)) cache_video(
tensor=images,
save_file=self.config.save_video_path,
fps=self.config.get("fps", 16),
nrow=1,
normalize=True,
value_range=(-1, 1),
)
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantization weight
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls hunyuan \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_i2v_save_quant.json \
--prompt "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_1.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_i2v.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls hunyuan \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_i2v_save_quant.json \
--prompt "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_1.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_i2v.mp4
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantization weight
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls hunyuan \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_t2v_save_quant.json \
--prompt "A cat walks on the grass, realistic style." \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_t2v.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls hunyuan \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_t2v_save_quant.json \
--prompt "A cat walks on the grass, realistic style." \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_t2v.mp4
...@@ -7,7 +7,7 @@ model_path= ...@@ -7,7 +7,7 @@ model_path=
# check section # check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0 cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices} export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi fi
...@@ -28,37 +28,12 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -28,37 +28,12 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
# =========================
# save quantization weight
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v_save_quant.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \ python -m lightx2v.infer \
--model_cls wan2.1 \ --model_cls wan2.1 \
--task i2v \ --task i2v \
--model_path $model_path \ --model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v_save_quant.json \ --config_json ${lightx2v_path}/configs/quantization/wan_i2v_quant_auto.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ --negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ --image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4 --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_tea.mp4
...@@ -7,7 +7,7 @@ model_path= ...@@ -7,7 +7,7 @@ model_path=
# check section # check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0 cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices} export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi fi
...@@ -28,35 +28,12 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -28,35 +28,12 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
# =========================
# save quantization weight
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_save_quant.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \ python -m lightx2v.infer \
--model_cls wan2.1 \ --model_cls wan2.1 \
--task t2v \ --task i2v \
--model_path $model_path \ --model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_save_quant.json \ --config_json ${lightx2v_path}/configs/quantization/wan_i2v_quant_offline.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ --negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4 --image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_tea.mp4
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantization weight
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls wan2.1_causvid \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_causvid_save_quant.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_causvid.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls wan2.1_causvid \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_causvid_save_quant.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_causvid.mp4
# Model Conversion Tool
A powerful utility for converting model weights between different formats and performing quantization tasks.
## Diffusers
Facilitates bidirectional conversion between the Diffusers architecture and the LightX2V architecture.
### Lightx2v->Diffusers
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P \
--output /Path/To/Wan2.1-I2V-14B-480P-Diffusers \
--direction forward
```
### Diffusers->Lightx2v
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers \
--output /Path/To/Wan2.1-I2V-14B-480P \
--direction backward
```
## Quantization
This tool supports converting FP32/FP16/BF16 model weights to INT8 or FP8 types.
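The sections below give the exact `converter.py` commands for each model. As intuition for what the INT8 conversion produces, here is a minimal, self-contained sketch of symmetric per-output-channel weight quantization; it is only an assumption about the general recipe (the function and variable names are hypothetical), not the actual `converter.py` code path:

```python
import torch

def quantize_weight_int8(weight: torch.Tensor):
    """Illustrative symmetric per-output-channel INT8 quantization.

    weight: [out_features, in_features] tensor in fp32/fp16/bf16.
    Returns the int8 weight and one fp32 scale per output channel.
    """
    w = weight.float()
    # One scale per output row; 127 is the symmetric int8 range limit.
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    w_int8 = torch.round(w / scale).clamp(-128, 127).to(torch.int8)
    return w_int8, scale.to(torch.float32)

# Random weight standing in for a real checkpoint tensor.
w = torch.randn(4096, 1024, dtype=torch.bfloat16)
w_int8, w_scale = quantize_weight_int8(w)
max_err = (w.float() - w_int8.float() * w_scale).abs().max()
print(w_int8.dtype, w_scale.shape, float(max_err))
```

For FP8 output the same per-channel scaling idea applies, with the integer rounding step presumably replaced by a cast to `torch.float8_e4m3fn`.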
### Wan DIT
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_name wan_int8 \
--dtype torch.int8 \
--model_type wan_dit
```
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_name wan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_dit
```
### Hunyuan DIT
```bash
python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_name hunyuan_int8 \
--dtype torch.int8 \
--model_type hunyuan_dit
```
```bash
python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_name hunyuan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type hunyuan_dit
```
### Wan T5EncoderModel
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/output \
--output_ext .pth\
--output_name models_t5_umt5-xxl-enc-int8 \
--dtype torch.int8 \
--model_type wan_t5
```
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/output \
--output_ext .pth\
--output_name models_t5_umt5-xxl-enc-fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_t5
```
### Wan CLIPModel
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--quantized \
--output /Path/To/output \
--output_ext .pth \
--output_name clip_int8.pth \
--dtype torch.int8 \
--model_type wan_clip
```
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--quantized \
--output /Path/To/output \
--output_ext .pth \
--output_name clip_fp8.pth \
--dtype torch.float8_e4m3fn \
--model_type wan_clip
```
# Model Conversion Tool
A powerful utility for converting model weights between different formats and performing quantization tasks.
## Diffusers
Supports bidirectional conversion between the Diffusers architecture and the LightX2V architecture.
### Lightx2v->Diffusers
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P \
--output /Path/To/Wan2.1-I2V-14B-480P-Diffusers \
--direction forward
```
### Diffusers->Lightx2v
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers \
--output /Path/To/Wan2.1-I2V-14B-480P \
--direction backward
```
## Quantization
This tool supports converting **FP32/FP16/BF16** model weights to **INT8** or **FP8** types.
### Wan DIT
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_name wan_int8 \
--dtype torch.int8 \
--model_type wan_dit
```
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_name wan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_dit
```
### Hunyuan DIT
```bash
python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_name hunyuan_int8 \
--dtype torch.int8 \
--model_type hunyuan_dit
```
```bash
python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_name hunyuan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type hunyuan_dit
```
### Wan T5EncoderModel
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/output \
--output_ext .pth\
--output_name models_t5_umt5-xxl-enc-int8 \
--dtype torch.int8 \
--model_type wan_t5
```
```bash
python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/output \
--output_ext .pth\
--output_name models_t5_umt5-xxl-enc-fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_t5
```
### Wan CLIPModel
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--quantized \
--output /Path/To/output \
--output_ext .pth \
--output_name clip_int8.pth \
--dtype torch.int8 \
--model_type wan_clip
```
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--quantized \
--output /Path/To/output \
--output_ext .pth \
--output_name clip_fp8.pth \
--dtype torch.float8_e4m3fn \
--model_type wan_clip
```