Commit a50bcc53 authored by Dongz, committed by Yang Yong (雍洋)

add lint feature and minor fix (#7)

* [minor]: optimize Dockerfile for fewer layers

* [feature]: add pre-commit lint; update README with contribution guidance

* [minor]: fix run shell script privileges

* [auto]: first lint without rule F, fix rule E

* [minor]: fix Dockerfile error
parent 3b460075
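The pre-commit lint added here is not shown in full in this diff; assuming it simply wires ruff into pre-commit using the [tool.ruff] settings that appear further down, a typical local workflow would look roughly like the following (standard pre-commit/ruff commands, not quoted from this repository):

pip install pre-commit ruff      # assumed prerequisites for the new hook
pre-commit install               # register the git hook so lint runs on every commit
pre-commit run --all-files       # one-off repo-wide pass, as in the "[auto]: first lint" commit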
......@@ -39,28 +39,15 @@ def init_weights(m):
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
elif isinstance(m, T5RelativeEmbedding):
nn.init.normal_(
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5
)
nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
class GELU(nn.Module):
def forward(self, x):
return (
0.5
* x
* (
1.0
+ torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
)
)
)
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super(T5LayerNorm, self).__init__()
self.dim = dim
......@@ -75,7 +62,6 @@ class T5LayerNorm(nn.Module):
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
......@@ -128,7 +114,6 @@ class T5Attention(nn.Module):
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1):
super(T5FeedForward, self).__init__()
self.dim = dim
......@@ -149,7 +134,6 @@ class T5FeedForward(nn.Module):
class T5SelfAttention(nn.Module):
def __init__(
self,
dim,
......@@ -173,11 +157,7 @@ class T5SelfAttention(nn.Module):
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = (
None
if shared_pos
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
......@@ -187,7 +167,6 @@ class T5SelfAttention(nn.Module):
class T5CrossAttention(nn.Module):
def __init__(
self,
dim,
......@@ -213,27 +192,17 @@ class T5CrossAttention(nn.Module):
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = (
None
if shared_pos
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
def forward(
self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None
):
def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(
x
+ self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)
)
x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super(T5RelativeEmbedding, self).__init__()
self.num_buckets = num_buckets
......@@ -248,9 +217,7 @@ class T5RelativeEmbedding(nn.Module):
device = self.embedding.weight.device
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
# torch.arange(lq).unsqueeze(1).to(device)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(
lq, device=device
).unsqueeze(1)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
rel_pos = self._relative_position_bucket(rel_pos)
rel_pos_embeds = self.embedding(rel_pos)
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
......@@ -269,23 +236,13 @@ class T5RelativeEmbedding(nn.Module):
# embeddings for small and large positions
max_exact = num_buckets // 2
rel_pos_large = (
max_exact
+ (
torch.log(rel_pos.float() / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).long()
)
rel_pos_large = torch.min(
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)
)
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long()
rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
return rel_buckets
class T5Encoder(nn.Module):
def __init__(
self,
vocab,
......@@ -308,23 +265,10 @@ class T5Encoder(nn.Module):
self.shared_pos = shared_pos
# layers
self.token_embedding = (
vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
if shared_pos
else None
)
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList(
[
T5SelfAttention(
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
)
for _ in range(num_layers)
]
)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim)
# initialize weights
......@@ -342,7 +286,6 @@ class T5Encoder(nn.Module):
class T5Decoder(nn.Module):
def __init__(
self,
vocab,
......@@ -365,23 +308,10 @@ class T5Decoder(nn.Module):
self.shared_pos = shared_pos
# layers
self.token_embedding = (
vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
if shared_pos
else None
)
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList(
[
T5CrossAttention(
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
)
for _ in range(num_layers)
]
)
self.blocks = nn.ModuleList([T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim)
# initialize weights
......@@ -408,7 +338,6 @@ class T5Decoder(nn.Module):
class T5Model(nn.Module):
def __init__(
self,
vocab_size,
......@@ -530,7 +459,6 @@ def umt5_xxl(**kwargs):
class T5EncoderModel:
def __init__(
self,
text_len,
......@@ -547,13 +475,7 @@ class T5EncoderModel:
self.tokenizer_path = tokenizer_path
# init model
model = (
umt5_xxl(
encoder_only=True, return_tokenizer=False, dtype=dtype, device=device
)
.eval()
.requires_grad_(False)
)
model = umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device).eval().requires_grad_(False)
logging.info(f"loading {checkpoint_path}")
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
self.model = model
......@@ -562,9 +484,7 @@ class T5EncoderModel:
else:
self.model.to(self.device)
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path, seq_len=text_len, clean="whitespace"
)
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
def to_cpu(self):
self.model = self.model.to("cpu")
......
......@@ -24,10 +24,7 @@ def whitespace_clean(text):
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace("_", " ")
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(str.maketrans("", "", string.punctuation))
for part in text.split(keep_punctuation_exact_string)
)
text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in text.split(keep_punctuation_exact_string))
else:
text = text.translate(str.maketrans("", "", string.punctuation))
text = text.lower()
......@@ -36,7 +33,6 @@ def canonicalize(text, keep_punctuation_exact_string=None):
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, "whitespace", "lower", "canonicalize")
self.name = name
......
......@@ -3,7 +3,7 @@ import torch
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
class VideoEncoderKLCausal3DModel():
class VideoEncoderKLCausal3DModel:
def __init__(self, model_path, dtype, device):
self.model_path = model_path
self.dtype = dtype
......@@ -11,10 +11,10 @@ class VideoEncoderKLCausal3DModel():
self.load()
def load(self):
self.vae_path = os.path.join(self.model_path, 'hunyuan-video-t2v-720p/vae')
self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
config = AutoencoderKLCausal3D.load_config(self.vae_path)
self.model = AutoencoderKLCausal3D.from_config(config)
ckpt = torch.load(os.path.join(self.vae_path, 'pytorch_model.pt'), map_location='cpu', weights_only=True)
ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
self.model.load_state_dict(ckpt)
self.model = self.model.to(dtype=self.dtype, device=self.device)
self.model.requires_grad_(False)
......@@ -32,14 +32,13 @@ class VideoEncoderKLCausal3DModel():
latents = latents / self.model.config.scaling_factor
latents = latents.to(dtype=self.dtype, device=torch.device("cuda"))
self.model.enable_tiling()
image = self.model.decode(
latents, return_dict=False, generator=generator
)[0]
image = self.model.decode(latents, return_dict=False, generator=generator)[0]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().float()
if args.cpu_offload:
self.to_cpu()
return image
if __name__ == "__main__":
vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme0/yongyang/projects/hy/new/HunyuanVideo/ckpts", dtype=torch.float16, device=torch.device("cuda"))
......@@ -44,7 +44,6 @@ class CausalConv3d(nn.Conv3d):
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
......@@ -56,16 +55,10 @@ class RMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return (
F.normalize(x, dim=(1 if self.channel_first else -1))
* self.scale
* self.gamma
+ self.bias
)
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
......@@ -74,7 +67,6 @@ class Upsample(nn.Upsample):
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in (
"none",
......@@ -101,16 +93,10 @@ class Resample(nn.Module):
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
......@@ -124,28 +110,17 @@ class Resample(nn.Module):
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] != "Rep"
):
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] == "Rep"
):
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat(
[torch.zeros_like(cache_x).to(cache_x.device), cache_x],
dim=2,
......@@ -172,15 +147,12 @@ class Resample(nn.Module):
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
)
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
......@@ -210,7 +182,6 @@ class Resample(nn.Module):
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
......@@ -226,9 +197,7 @@ class ResidualBlock(nn.Module):
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
......@@ -240,9 +209,7 @@ class ResidualBlock(nn.Module):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
......@@ -278,13 +245,7 @@ class AttentionBlock(nn.Module):
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm(x)
# compute query, key, value
q, k, v = (
self.to_qkv(x)
.reshape(b * t, 1, c * 3, -1)
.permute(0, 1, 3, 2)
.contiguous()
.chunk(3, dim=-1)
)
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
......@@ -301,7 +262,6 @@ class AttentionBlock(nn.Module):
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
......@@ -400,9 +360,7 @@ class Encoder3d(nn.Module):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
......@@ -416,7 +374,6 @@ class Encoder3d(nn.Module):
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
......@@ -518,9 +475,7 @@ class Decoder3d(nn.Module):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
......@@ -542,7 +497,6 @@ def count_conv3d(model):
class WanVAE_(nn.Module):
def __init__(
self,
dim=128,
......@@ -613,9 +567,7 @@ class WanVAE_(nn.Module):
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1
)
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
......@@ -625,9 +577,7 @@ class WanVAE_(nn.Module):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1
)
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
......@@ -700,7 +650,6 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
class WanVAE:
def __init__(
self,
z_dim=16,
......@@ -780,11 +729,8 @@ class WanVAE:
"""
videos: A list of videos each with shape [C, T, H, W].
"""
return [
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
]
return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
def decode_dist(self, zs, world_size, cur_rank, split_dim):
splited_total_len = zs.shape[split_dim]
splited_chunk_len = splited_total_len // world_size
......@@ -792,37 +738,37 @@ class WanVAE:
if cur_rank == 0:
if split_dim == 2:
zs = zs[:,:,:splited_chunk_len+2*padding_size,:].contiguous()
zs = zs[:, :, : splited_chunk_len + 2 * padding_size, :].contiguous()
elif split_dim == 3:
zs = zs[:,:,:,:splited_chunk_len+2*padding_size].contiguous()
elif cur_rank == world_size-1:
zs = zs[:, :, :, : splited_chunk_len + 2 * padding_size].contiguous()
elif cur_rank == world_size - 1:
if split_dim == 2:
zs = zs[:,:,-(splited_chunk_len+2*padding_size):,:].contiguous()
zs = zs[:, :, -(splited_chunk_len + 2 * padding_size) :, :].contiguous()
elif split_dim == 3:
zs = zs[:,:,:,-(splited_chunk_len+2*padding_size):].contiguous()
zs = zs[:, :, :, -(splited_chunk_len + 2 * padding_size) :].contiguous()
else:
if split_dim == 2:
zs = zs[:,:,cur_rank*splited_chunk_len-padding_size:(cur_rank+1)*splited_chunk_len+padding_size,:].contiguous()
zs = zs[:, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size, :].contiguous()
elif split_dim == 3:
zs = zs[:,:,:,cur_rank*splited_chunk_len-padding_size:(cur_rank+1)*splited_chunk_len+padding_size].contiguous()
zs = zs[:, :, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size].contiguous()
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
if cur_rank == 0:
if split_dim == 2:
images = images[:,:,:,:splited_chunk_len*8,:].contiguous()
images = images[:, :, :, : splited_chunk_len * 8, :].contiguous()
elif split_dim == 3:
images = images[:,:,:,:,:splited_chunk_len*8].contiguous()
elif cur_rank == world_size-1:
images = images[:, :, :, :, : splited_chunk_len * 8].contiguous()
elif cur_rank == world_size - 1:
if split_dim == 2:
images = images[:,:,:,-splited_chunk_len*8:,:].contiguous()
images = images[:, :, :, -splited_chunk_len * 8 :, :].contiguous()
elif split_dim == 3:
images = images[:,:,:,:,-splited_chunk_len*8:].contiguous()
images = images[:, :, :, :, -splited_chunk_len * 8 :].contiguous()
else:
if split_dim == 2:
images = images[:,:,:,8*padding_size:-8*padding_size,:].contiguous()
images = images[:, :, :, 8 * padding_size : -8 * padding_size, :].contiguous()
elif split_dim == 3:
images = images[:,:,:,:,8*padding_size:-8*padding_size].contiguous()
images = images[:, :, :, :, 8 * padding_size : -8 * padding_size].contiguous()
full_images = [torch.empty_like(images) for _ in range(world_size)]
dist.all_gather(full_images, images)
......@@ -832,7 +778,6 @@ class WanVAE:
images = torch.cat(full_images, dim=-1)
return images
def decode(self, zs, generator, args):
if args.cpu_offload:
......
......@@ -5,7 +5,7 @@ from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.autoencod
from lightx2v.text2v.models.video_encoders.trt.autoencoder_kl_causal_3d import trt_vae_infer
class VideoEncoderKLCausal3DModel():
class VideoEncoderKLCausal3DModel:
def __init__(self, model_path, dtype, device):
self.model_path = model_path
self.dtype = dtype
......@@ -13,10 +13,10 @@ class VideoEncoderKLCausal3DModel():
self.load()
def load(self):
self.vae_path = os.path.join(self.model_path, 'hunyuan-video-t2v-720p/vae')
self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
config = AutoencoderKLCausal3D.load_config(self.vae_path)
self.model = AutoencoderKLCausal3D.from_config(config)
ckpt = torch.load(os.path.join(self.vae_path, 'pytorch_model.pt'), map_location='cpu', weights_only=True)
ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
self.model.load_state_dict(ckpt)
self.model = self.model.to(dtype=self.dtype, device=self.device)
self.model.requires_grad_(False)
......@@ -28,12 +28,11 @@ class VideoEncoderKLCausal3DModel():
latents = latents / self.model.config.scaling_factor
latents = latents.to(dtype=self.dtype, device=self.device)
self.model.enable_tiling()
image = self.model.decode(
latents, return_dict=False, generator=generator
)[0]
image = self.model.decode(latents, return_dict=False, generator=generator)[0]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().float()
return image
if __name__ == "__main__":
vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme1/yongyang/models/hy/ckpts", dtype=torch.float16, device=torch.device("cuda"))
......@@ -100,18 +100,18 @@ class HyVaeTrtModelInfer(nn.Module):
device = batch.device
dtype = batch.dtype
batch = batch.cpu().numpy()
def get_output_shape(shp):
b, c, t, h, w = shp
out = (b, 3, 4*(t-1)+1, h*8, w*8)
out = (b, 3, 4 * (t - 1) + 1, h * 8, w * 8)
return out
shp_dict = {"inp": batch.shape, "out": get_output_shape(batch.shape)}
self.alloc(shp_dict)
output = np.zeros(*self.output_spec())
# Process I/O and execute the network
common.memcpy_host_to_device(
self.inputs[0]["allocation"], np.ascontiguousarray(batch)
)
common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(batch))
self.context.execute_v2(self.allocations)
common.memcpy_device_to_host(output, self.outputs[0]["allocation"])
output = torch.from_numpy(output).to(device).type(dtype)
......@@ -122,19 +122,16 @@ class HyVaeTrtModelInfer(nn.Module):
logger.info("Start to do VAE onnx exporting.")
device = next(decoder.parameters())[0].device
example_inp = torch.rand(1, 16, 17, 32, 32).to(device).type(next(decoder.parameters())[0].dtype)
out_path = str(Path(str(model_dir))/"vae_decoder.onnx")
out_path = str(Path(str(model_dir)) / "vae_decoder.onnx")
torch.onnx.export(
decoder.eval().half(),
example_inp.half(),
out_path,
input_names=['inp'],
output_names=['out'],
input_names=["inp"],
output_names=["out"],
opset_version=14,
dynamic_axes={
"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"},
"out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}
}
)
dynamic_axes={"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}, "out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}},
)
# onnx_ori = onnx.load(out_path)
os.system(f"onnxsim {out_path} {out_path}")
# onnx_opt, check = simplify(onnx_ori)
......@@ -163,4 +160,4 @@ class HyVaeTrtModelInfer(nn.Module):
if not Path(engine_path).exists():
raise RuntimeError(f"Convert vae onnx({onnx_path}) to tensorrt engine failed.")
logger.info("Finish VAE tensorrt converting.")
return engine_path
\ No newline at end of file
return engine_path
......@@ -8,20 +8,20 @@ class BaseQuantizer(object):
self.sym = symmetric
self.granularity = granularity
self.kwargs = kwargs
if self.granularity == 'per_group':
self.group_size = self.kwargs['group_size']
self.calib_algo = self.kwargs.get('calib_algo', 'minmax')
if self.granularity == "per_group":
self.group_size = self.kwargs["group_size"]
self.calib_algo = self.kwargs.get("calib_algo", "minmax")
def get_tensor_range(self, tensor):
if self.calib_algo == 'minmax':
if self.calib_algo == "minmax":
return self.get_minmax_range(tensor)
elif self.calib_algo == 'mse':
elif self.calib_algo == "mse":
return self.get_mse_range(tensor)
else:
raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}')
raise ValueError(f"Unsupported calibration algorithm: {self.calib_algo}")
def get_minmax_range(self, tensor):
if self.granularity == 'per_tensor':
if self.granularity == "per_tensor":
max_val = torch.max(tensor)
min_val = torch.min(tensor)
else:
......@@ -47,7 +47,7 @@ class BaseQuantizer(object):
return scales, zeros, qmax, qmin
def reshape_tensor(self, tensor, allow_padding=False):
if self.granularity == 'per_group':
if self.granularity == "per_group":
t = tensor.reshape(-1, self.group_size)
else:
t = tensor
......@@ -79,7 +79,7 @@ class BaseQuantizer(object):
tensor, scales, zeros, qmax, qmin = self.get_tensor_qparams(tensor)
tensor = self.quant(tensor, scales, zeros, qmax, qmin)
tensor = self.restore_tensor(tensor, org_shape)
if self.sym == True:
if self.sym:
zeros = None
return tensor, scales, zeros
......@@ -87,9 +87,9 @@ class BaseQuantizer(object):
class IntegerQuantizer(BaseQuantizer):
def __init__(self, bit, symmetric, granularity, **kwargs):
super().__init__(bit, symmetric, granularity, **kwargs)
if 'int_range' in self.kwargs:
self.qmin = self.kwargs['int_range'][0]
self.qmax = self.kwargs['int_range'][1]
if "int_range" in self.kwargs:
self.qmin = self.kwargs["int_range"][0]
self.qmax = self.kwargs["int_range"][1]
else:
if self.sym:
self.qmin = -(2 ** (self.bit - 1))
......@@ -110,7 +110,14 @@ class IntegerQuantizer(BaseQuantizer):
tensor = (tensor - zeros) * scales
return tensor
def quant_dequant(self, tensor, scales, zeros, qmax, qmin,):
def quant_dequant(
self,
tensor,
scales,
zeros,
qmax,
qmin,
):
tensor = self.quant(tensor, scales, zeros, qmax, qmin)
tensor = self.dequant(tensor, scales, zeros)
return tensor
......@@ -119,19 +126,19 @@ class IntegerQuantizer(BaseQuantizer):
class FloatQuantizer(BaseQuantizer):
def __init__(self, bit, symmetric, granularity, **kwargs):
super().__init__(bit, symmetric, granularity, **kwargs)
assert self.bit in ['e4m3', 'e5m2'], f'Unsupported bit configuration: {self.bit}'
assert self.sym == True
if self.bit == 'e4m3':
assert self.bit in ["e4m3", "e5m2"], f"Unsupported bit configuration: {self.bit}"
assert self.sym
if self.bit == "e4m3":
self.e_bits = 4
self.m_bits = 3
self.fp_dtype = torch.float8_e4m3fn
elif self.bit == 'e5m2':
elif self.bit == "e5m2":
self.e_bits = 5
self.m_bits = 2
self.fp_dtype = torch.float8_e5m2
else:
raise ValueError(f'Unsupported bit configuration: {self.bit}')
raise ValueError(f"Unsupported bit configuration: {self.bit}")
finfo = torch.finfo(self.fp_dtype)
self.qmin, self.qmax = finfo.min, finfo.max
......@@ -141,13 +148,9 @@ class FloatQuantizer(BaseQuantizer):
def quant(self, tensor, scales, zeros, qmax, qmin):
scaled_tensor = tensor / scales + zeros
scaled_tensor = torch.clip(
scaled_tensor, self.qmin.cuda(), self.qmax.cuda()
)
scaled_tensor = torch.clip(scaled_tensor, self.qmin.cuda(), self.qmax.cuda())
org_dtype = scaled_tensor.dtype
q_tensor = float_quantize(
scaled_tensor.float(), self.e_bits, self.m_bits, rounding='nearest'
)
q_tensor = float_quantize(scaled_tensor.float(), self.e_bits, self.m_bits, rounding="nearest")
q_tensor.to(org_dtype)
return q_tensor
......@@ -161,21 +164,21 @@ class FloatQuantizer(BaseQuantizer):
return tensor
if __name__ == '__main__':
if __name__ == "__main__":
weight = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda()
quantizer = IntegerQuantizer(4, False, 'per_group', group_size=128)
quantizer = IntegerQuantizer(4, False, "per_group", group_size=128)
q_weight = quantizer.fake_quant_tensor(weight)
print(weight)
print(q_weight)
print(f"cosine = {torch.cosine_similarity(weight.view(1, -1).to(torch.float64), q_weight.view(1, -1).to(torch.float64))}")
realq_weight, scales, zeros = quantizer.real_quant_tensor(weight)
print(f"realq_weight = {realq_weight}, {realq_weight.shape}")
print(f"scales = {scales}, {scales.shape}")
print(f"zeros = {zeros}, {zeros.shape}")
weight = torch.randn(8192, 4096, dtype=torch.bfloat16).cuda()
quantizer = FloatQuantizer('e4m3', True, 'per_channel')
quantizer = FloatQuantizer("e4m3", True, "per_channel")
q_weight = quantizer.fake_quant_tensor(weight)
print(weight)
print(q_weight)
......
......@@ -11,14 +11,14 @@ class Register(dict):
def register(self, target, key=None):
if not callable(target):
raise Exception(f'Error: {target} must be callable!')
raise Exception(f"Error: {target} must be callable!")
if key is None:
key = target.__name__
if key in self._dict:
raise Exception(f'{key} already exists.')
raise Exception(f"{key} already exists.")
self[key] = target
return target
......
......@@ -66,12 +66,7 @@ def cache_video(
# preprocess
tensor = tensor.clamp(min(value_range), max(value_range))
tensor = torch.stack(
[
torchvision.utils.make_grid(
u, nrow=nrow, normalize=normalize, value_range=value_range
)
for u in tensor.unbind(2)
],
[torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range) for u in tensor.unbind(2)],
dim=1,
).permute(1, 2, 3, 0)
tensor = (tensor * 255).type(torch.uint8).cpu()
......
[tool.ruff]
exclude = [".git", ".mypy_cache", ".ruff_cache", ".venv", "dist"]
target-version = "py311"
line-length = 200
indent-width = 4
lint.ignore = ["F"]
[tool.ruff.format]
line-ending = "lf"
quote-style = "double"
indent-style = "space"
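With this configuration in place, the single-line rewrites seen throughout the diff above can be reproduced manually; assuming ruff is installed, the standard commands are:

ruff check . --fix    # lint and auto-fix; rule family F is skipped via lint.ignore = ["F"]
ruff format .         # enforce double quotes, LF line endings, and the 200-column limit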
......@@ -18,4 +18,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
--target_width 1280 \
--attention_type flash_attn3 \
--save_video_path ./output_lightx2v_int8.mp4 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
......@@ -18,4 +18,4 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--target_width 1280 \
--attention_type flash_attn2 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
--parallel_attn
\ No newline at end of file
--parallel_attn
......@@ -20,4 +20,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
--cpu_offload \
--feature_caching TaylorSeer \
--save_video_path ./output_lightx2v_offload_TaylorSeer.mp4 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
\ No newline at end of file
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
......@@ -28,4 +28,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--mm_config '{"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
# --feature_caching Tea \
# --use_ret_steps \
\ No newline at end of file
# --use_ret_steps \
......@@ -28,4 +28,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F", "weight_auto_quant": true}' \
# --feature_caching Tea \
# --use_ret_steps \
# --teacache_thresh 0.2
\ No newline at end of file
# --teacache_thresh 0.2
File mode changed from 100644 to 100755
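The mode change from 100644 to 100755 is how git records the execute bit, matching the "fix run shell script privileges" commit above; locally it corresponds to something like (path hypothetical):

chmod +x path/to/run_script.sh    # hypothetical script path; grants execute permission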