Commit a50bcc53 authored by Dongz, committed by Yang Yong (雍洋)

add lint feature and minor fix (#7)

* [minor]: optimize Dockerfile for fewer layers

* [feature]: add pre-commit lint, update README for contribution guidance (a minimal hook config is sketched below the commit metadata)

* [minor]: fix run shell script permissions

* [auto]: first lint without rule F, fix rule E violations

* [minor]: fix Dockerfile error
parent 3b460075
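The pre-commit wiring referenced in the second bullet is not part of this diff excerpt; a minimal sketch of what a .pre-commit-config.yaml could look like, assuming the standard astral-sh/ruff-pre-commit hooks and a placeholder revision (not the one actually pinned by this commit), is:

repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.9   # placeholder revision, pin whichever release the repository actually uses
    hooks:
      - id: ruff          # lint step, reads the [tool.ruff] configuration added at the end of this diff
      - id: ruff-format   # format step, reads [tool.ruff.format]

Contributors would typically enable it once with "pip install pre-commit" and "pre-commit install", then run "pre-commit run --all-files" before pushing.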
@@ -39,28 +39,15 @@ def init_weights(m):
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)

class GELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

class T5LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim

@@ -75,7 +62,6 @@ class T5LayerNorm(nn.Module):
class T5Attention(nn.Module):
    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()

@@ -128,7 +114,6 @@ class T5Attention(nn.Module):
class T5FeedForward(nn.Module):
    def __init__(self, dim, dim_ffn, dropout=0.1):
        super(T5FeedForward, self).__init__()
        self.dim = dim

@@ -149,7 +134,6 @@ class T5FeedForward(nn.Module):
class T5SelfAttention(nn.Module):
    def __init__(
        self,
        dim,

@@ -173,11 +157,7 @@ class T5SelfAttention(nn.Module):
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))

@@ -187,7 +167,6 @@ class T5SelfAttention(nn.Module):
class T5CrossAttention(nn.Module):
    def __init__(
        self,
        dim,

@@ -213,27 +192,17 @@ class T5CrossAttention(nn.Module):
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)

    def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x

class T5RelativeEmbedding(nn.Module):
    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets

@@ -248,9 +217,7 @@ class T5RelativeEmbedding(nn.Module):
        device = self.embedding.weight.device
        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
        #     torch.arange(lq).unsqueeze(1).to(device)
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0)  # [1, N, Lq, Lk]

@@ -269,23 +236,13 @@ class T5RelativeEmbedding(nn.Module):
        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets

class T5Encoder(nn.Module):
    def __init__(
        self,
        vocab,

@@ -308,23 +265,10 @@ class T5Encoder(nn.Module):
        self.shared_pos = shared_pos
        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
        self.norm = T5LayerNorm(dim)
        # initialize weights

@@ -342,7 +286,6 @@ class T5Encoder(nn.Module):
class T5Decoder(nn.Module):
    def __init__(
        self,
        vocab,

@@ -365,23 +308,10 @@ class T5Decoder(nn.Module):
        self.shared_pos = shared_pos
        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
        self.norm = T5LayerNorm(dim)
        # initialize weights

@@ -408,7 +338,6 @@ class T5Decoder(nn.Module):
class T5Model(nn.Module):
    def __init__(
        self,
        vocab_size,

@@ -530,7 +459,6 @@ def umt5_xxl(**kwargs):
class T5EncoderModel:
    def __init__(
        self,
        text_len,

@@ -547,13 +475,7 @@ class T5EncoderModel:
        self.tokenizer_path = tokenizer_path
        # init model
        model = umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device).eval().requires_grad_(False)
        logging.info(f"loading {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
        self.model = model

@@ -562,9 +484,7 @@ class T5EncoderModel:
        else:
            self.model.to(self.device)
        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")

    def to_cpu(self):
        self.model = self.model.to("cpu")
...
@@ -24,10 +24,7 @@ def whitespace_clean(text):
def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()

@@ -36,7 +33,6 @@ def canonicalize(text, keep_punctuation_exact_string=None):
class HuggingfaceTokenizer:
    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, "whitespace", "lower", "canonicalize")
        self.name = name
...
@@ -123,11 +123,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
        self.tile_sample_min_size = self.config.sample_size
        sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

@@ -204,9 +200,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
        r"""
        Sets the attention processor to use to compute attention.

@@ -250,16 +244,12 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}")
        self.set_attn_processor(processor, _remove_lora=True)

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images/videos into latents.

@@ -312,9 +302,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of images/videos.

@@ -386,7 +374,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        for i in range(0, x.shape[-2], overlap_size):
            row = []
            for j in range(0, x.shape[-1], overlap_size):
                tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)

@@ -438,7 +426,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        for i in range(0, z.shape[-2], overlap_size):
            row = []
            for j in range(0, z.shape[-1], overlap_size):
                tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)

@@ -463,7 +451,6 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        return DecoderOutput(sample=dec)

    def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        B, C, T, H, W = x.shape
        overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)

@@ -472,7 +459,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        # Split the video into tiles and encode them separately.
        row = []
        for i in range(0, T, overlap_size):
            tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
            if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
                tile = self.spatial_tiled_encode(tile, return_moments=True)
            else:

@@ -487,7 +474,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
                tile = self.blend_t(row[i - 1], tile, blend_extent)
                result_row.append(tile[:, :, :t_limit, :, :])
            else:
                result_row.append(tile[:, :, : t_limit + 1, :, :])
        moments = torch.cat(result_row, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

@@ -507,7 +494,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        row = []
        for i in range(0, T, overlap_size):
            tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
            if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
                decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
            else:

@@ -522,7 +509,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
                tile = self.blend_t(row[i - 1], tile, blend_extent)
                result_row.append(tile[:, :, :t_limit, :, :])
            else:
                result_row.append(tile[:, :, : t_limit + 1, :, :])
        dec = torch.cat(result_row, dim=2)
        if not return_dict:
...
@@ -3,7 +3,7 @@ import torch
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D

class VideoEncoderKLCausal3DModel:
    def __init__(self, model_path, dtype, device):
        self.model_path = model_path
        self.dtype = dtype

@@ -11,10 +11,10 @@ class VideoEncoderKLCausal3DModel():
        self.load()

    def load(self):
        self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
        config = AutoencoderKLCausal3D.load_config(self.vae_path)
        self.model = AutoencoderKLCausal3D.from_config(config)
        ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(dtype=self.dtype, device=self.device)
        self.model.requires_grad_(False)

@@ -32,14 +32,13 @@ class VideoEncoderKLCausal3DModel():
        latents = latents / self.model.config.scaling_factor
        latents = latents.to(dtype=self.dtype, device=torch.device("cuda"))
        self.model.enable_tiling()
        image = self.model.decode(latents, return_dict=False, generator=generator)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().float()
        if args.cpu_offload:
            self.to_cpu()
        return image

if __name__ == "__main__":
    vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme0/yongyang/projects/hy/new/HunyuanVideo/ckpts", dtype=torch.float16, device=torch.device("cuda"))
@@ -53,16 +53,15 @@ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_
    idx_arr = idx_arr > torch.zeros_like(idx_arr)
    for i in range(n_frame):
        for j in range(n_frame):
            if idx_arr[i, j]:
                mask[i, j] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
    # mask[idx_arr] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
    mask = mask.view(n_frame, -1, n_hw).transpose(1, 0).reshape(seq_len, -1).transpose(1, 0)
    if batch_size is not None:
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask.to(device)

class CausalConv3d(nn.Module):
    """
    Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.

@@ -76,8 +75,8 @@ class CausalConv3d(nn.Module):
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        pad_mode="replicate",
        **kwargs,
    ):
        super().__init__()

@@ -238,9 +237,7 @@ class DownsampleCausal3D(nn.Module):
            raise ValueError(f"unknown norm_type: {norm_type}")
        if use_conv:
            conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
        else:
            raise NotImplementedError

@@ -384,28 +381,18 @@ class ResnetBlockCausal3D(nn.Module):
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor, scale=scale)
            hidden_states = self.upsample(hidden_states, scale=scale)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor, scale=scale)
            hidden_states = self.downsample(hidden_states, scale=scale)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb, scale)[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

@@ -425,9 +412,7 @@ class ResnetBlockCausal3D(nn.Module):
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

@@ -464,9 +449,7 @@ def get_down_block3d(
):
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
        attention_head_dim = num_attention_heads

    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type

@@ -518,9 +501,7 @@ def get_up_block3d(
) -> nn.Module:
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
        attention_head_dim = num_attention_heads

    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type

@@ -588,9 +569,7 @@ class UNetMidBlockCausal3D(nn.Module):
        attentions = []

        if attention_head_dim is None:
            logger.warn(f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}.")
            attention_head_dim = in_channels

        for _ in range(num_layers):

@@ -637,9 +616,7 @@ class UNetMidBlockCausal3D(nn.Module):
            if attn is not None:
                B, C, T, H, W = hidden_states.shape
                hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
                attention_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
                hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
                hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
            hidden_states = resnet(hidden_states, temb)

@@ -770,9 +747,7 @@ class UpDecoderBlockCausal3D(nn.Module):
        self.resolution_idx = resolution_idx

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=temb, scale=scale)
...
@@ -66,10 +66,7 @@ class EncoderCausal3D(nn.Module):
            if time_compression_ratio == 4:
                add_spatial_downsample = bool(i < num_spatial_downsample_layers)
                add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

@@ -186,10 +183,7 @@ class DecoderCausal3D(nn.Module):
            if time_compression_ratio == 4:
                add_spatial_upsample = bool(i < num_spatial_upsample_layers)
                add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

@@ -263,9 +257,7 @@ class DecoderCausal3D(nn.Module):
                )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
                sample = sample.to(upscale_dtype)

                # up

@@ -306,9 +298,7 @@ class DiagonalGaussianDistribution(object):
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        # make sure sample is on the same device as the parameters and has same dtype

@@ -333,11 +323,7 @@ class DiagonalGaussianDistribution(object):
            )
        else:
            return 0.5 * torch.sum(
                torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
                dim=reduce_dim,
            )

@@ -346,8 +332,7 @@ class DiagonalGaussianDistribution(object):
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )
...
@@ -44,7 +44,6 @@ class CausalConv3d(nn.Conv3d):
class RMS_norm(nn.Module):
    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)

@@ -56,16 +55,10 @@ class RMS_norm(nn.Module):
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias

class Upsample(nn.Upsample):
    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.

@@ -74,7 +67,6 @@ class Upsample(nn.Upsample):
class Resample(nn.Module):
    def __init__(self, dim, mode):
        assert mode in (
            "none",

@@ -101,16 +93,10 @@ class Resample(nn.Module):
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == "downsample2d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

@@ -124,28 +110,17 @@ class Resample(nn.Module):
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
                        # cache last frame of last two chunk
                        cache_x = torch.cat(
                            [
                                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                                cache_x,
                            ],
                            dim=2,
                        )
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
                        cache_x = torch.cat(
                            [torch.zeros_like(cache_x).to(cache_x.device), cache_x],
                            dim=2,

@@ -172,15 +147,12 @@ class Resample(nn.Module):
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
                    #     # cache last frame of last two chunk
                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

@@ -210,7 +182,6 @@ class Resample(nn.Module):
class ResidualBlock(nn.Module):
    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim

@@ -226,9 +197,7 @@ class ResidualBlock(nn.Module):
            nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1),
        )
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        h = self.shortcut(x)

@@ -240,9 +209,7 @@ class ResidualBlock(nn.Module):
                    # cache last frame of last two chunk
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                            cache_x,
                        ],
                        dim=2,

@@ -278,13 +245,7 @@ class AttentionBlock(nn.Module):
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.norm(x)
        # compute query, key, value
        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(

@@ -301,7 +262,6 @@ class AttentionBlock(nn.Module):
class Encoder3d(nn.Module):
    def __init__(
        self,
        dim=128,

@@ -400,9 +360,7 @@ class Encoder3d(nn.Module):
                    # cache last frame of last two chunk
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                            cache_x,
                        ],
                        dim=2,

@@ -416,7 +374,6 @@ class Encoder3d(nn.Module):
class Decoder3d(nn.Module):
    def __init__(
        self,
        dim=128,

@@ -518,9 +475,7 @@ class Decoder3d(nn.Module):
                    # cache last frame of last two chunk
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                            cache_x,
                        ],
                        dim=2,

@@ -542,7 +497,6 @@ def count_conv3d(model):
class WanVAE_(nn.Module):
    def __init__(
        self,
        dim=128,

@@ -613,9 +567,7 @@ class WanVAE_(nn.Module):
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()

@@ -625,9 +577,7 @@ class WanVAE_(nn.Module):
        self.clear_cache()
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]

@@ -700,7 +650,6 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
class WanVAE:
    def __init__(
        self,
        z_dim=16,

@@ -780,10 +729,7 @@ class WanVAE:
        """
        videos: A list of videos each with shape [C, T, H, W].
        """
        return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]

    def decode_dist(self, zs, world_size, cur_rank, split_dim):
        splited_total_len = zs.shape[split_dim]

@@ -792,37 +738,37 @@ class WanVAE:
        if cur_rank == 0:
            if split_dim == 2:
                zs = zs[:, :, : splited_chunk_len + 2 * padding_size, :].contiguous()
            elif split_dim == 3:
                zs = zs[:, :, :, : splited_chunk_len + 2 * padding_size].contiguous()
        elif cur_rank == world_size - 1:
            if split_dim == 2:
                zs = zs[:, :, -(splited_chunk_len + 2 * padding_size) :, :].contiguous()
            elif split_dim == 3:
                zs = zs[:, :, :, -(splited_chunk_len + 2 * padding_size) :].contiguous()
        else:
            if split_dim == 2:
                zs = zs[:, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size, :].contiguous()
            elif split_dim == 3:
                zs = zs[:, :, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size].contiguous()

        images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)

        if cur_rank == 0:
            if split_dim == 2:
                images = images[:, :, :, : splited_chunk_len * 8, :].contiguous()
            elif split_dim == 3:
                images = images[:, :, :, :, : splited_chunk_len * 8].contiguous()
        elif cur_rank == world_size - 1:
            if split_dim == 2:
                images = images[:, :, :, -splited_chunk_len * 8 :, :].contiguous()
            elif split_dim == 3:
                images = images[:, :, :, :, -splited_chunk_len * 8 :].contiguous()
        else:
            if split_dim == 2:
                images = images[:, :, :, 8 * padding_size : -8 * padding_size, :].contiguous()
            elif split_dim == 3:
                images = images[:, :, :, :, 8 * padding_size : -8 * padding_size].contiguous()

        full_images = [torch.empty_like(images) for _ in range(world_size)]
        dist.all_gather(full_images, images)

@@ -833,7 +779,6 @@ class WanVAE:
        return images

    def decode(self, zs, generator, args):
        if args.cpu_offload:
            self.to_cuda()
...
@@ -5,7 +5,7 @@ from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.autoencod
from lightx2v.text2v.models.video_encoders.trt.autoencoder_kl_causal_3d import trt_vae_infer

class VideoEncoderKLCausal3DModel:
    def __init__(self, model_path, dtype, device):
        self.model_path = model_path
        self.dtype = dtype

@@ -13,10 +13,10 @@ class VideoEncoderKLCausal3DModel():
        self.load()

    def load(self):
        self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
        config = AutoencoderKLCausal3D.load_config(self.vae_path)
        self.model = AutoencoderKLCausal3D.from_config(config)
        ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(dtype=self.dtype, device=self.device)
        self.model.requires_grad_(False)

@@ -28,12 +28,11 @@ class VideoEncoderKLCausal3DModel():
        latents = latents / self.model.config.scaling_factor
        latents = latents.to(dtype=self.dtype, device=self.device)
        self.model.enable_tiling()
        image = self.model.decode(latents, return_dict=False, generator=generator)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().float()
        return image

if __name__ == "__main__":
    vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme1/yongyang/models/hy/ckpts", dtype=torch.float16, device=torch.device("cuda"))

@@ -100,18 +100,18 @@ class HyVaeTrtModelInfer(nn.Module):
        device = batch.device
        dtype = batch.dtype
        batch = batch.cpu().numpy()

        def get_output_shape(shp):
            b, c, t, h, w = shp
            out = (b, 3, 4 * (t - 1) + 1, h * 8, w * 8)
            return out

        shp_dict = {"inp": batch.shape, "out": get_output_shape(batch.shape)}
        self.alloc(shp_dict)
        output = np.zeros(*self.output_spec())

        # Process I/O and execute the network
        common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(batch))
        self.context.execute_v2(self.allocations)
        common.memcpy_device_to_host(output, self.outputs[0]["allocation"])
        output = torch.from_numpy(output).to(device).type(dtype)

@@ -122,18 +122,15 @@ class HyVaeTrtModelInfer(nn.Module):
        logger.info("Start to do VAE onnx exporting.")
        device = next(decoder.parameters())[0].device
        example_inp = torch.rand(1, 16, 17, 32, 32).to(device).type(next(decoder.parameters())[0].dtype)
        out_path = str(Path(str(model_dir)) / "vae_decoder.onnx")
        torch.onnx.export(
            decoder.eval().half(),
            example_inp.half(),
            out_path,
            input_names=["inp"],
            output_names=["out"],
            opset_version=14,
            dynamic_axes={"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}, "out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}},
        )
        # onnx_ori = onnx.load(out_path)
        os.system(f"onnxsim {out_path} {out_path}")
...
@@ -8,20 +8,20 @@ class BaseQuantizer(object):
        self.sym = symmetric
        self.granularity = granularity
        self.kwargs = kwargs
        if self.granularity == "per_group":
            self.group_size = self.kwargs["group_size"]
        self.calib_algo = self.kwargs.get("calib_algo", "minmax")

    def get_tensor_range(self, tensor):
        if self.calib_algo == "minmax":
            return self.get_minmax_range(tensor)
        elif self.calib_algo == "mse":
            return self.get_mse_range(tensor)
        else:
            raise ValueError(f"Unsupported calibration algorithm: {self.calib_algo}")

    def get_minmax_range(self, tensor):
        if self.granularity == "per_tensor":
            max_val = torch.max(tensor)
            min_val = torch.min(tensor)
        else:

@@ -47,7 +47,7 @@ class BaseQuantizer(object):
        return scales, zeros, qmax, qmin

    def reshape_tensor(self, tensor, allow_padding=False):
        if self.granularity == "per_group":
            t = tensor.reshape(-1, self.group_size)
        else:
            t = tensor

@@ -79,7 +79,7 @@ class BaseQuantizer(object):
        tensor, scales, zeros, qmax, qmin = self.get_tensor_qparams(tensor)
        tensor = self.quant(tensor, scales, zeros, qmax, qmin)
        tensor = self.restore_tensor(tensor, org_shape)
        if self.sym:
            zeros = None
        return tensor, scales, zeros

@@ -87,9 +87,9 @@ class BaseQuantizer(object):
class IntegerQuantizer(BaseQuantizer):
    def __init__(self, bit, symmetric, granularity, **kwargs):
        super().__init__(bit, symmetric, granularity, **kwargs)
        if "int_range" in self.kwargs:
            self.qmin = self.kwargs["int_range"][0]
            self.qmax = self.kwargs["int_range"][1]
        else:
            if self.sym:
                self.qmin = -(2 ** (self.bit - 1))

@@ -110,7 +110,14 @@ class IntegerQuantizer(BaseQuantizer):
        tensor = (tensor - zeros) * scales
        return tensor

    def quant_dequant(
        self,
        tensor,
        scales,
        zeros,
        qmax,
        qmin,
    ):
        tensor = self.quant(tensor, scales, zeros, qmax, qmin)
        tensor = self.dequant(tensor, scales, zeros)
        return tensor

@@ -119,19 +126,19 @@ class IntegerQuantizer(BaseQuantizer):
class FloatQuantizer(BaseQuantizer):
    def __init__(self, bit, symmetric, granularity, **kwargs):
        super().__init__(bit, symmetric, granularity, **kwargs)
        assert self.bit in ["e4m3", "e5m2"], f"Unsupported bit configuration: {self.bit}"
        assert self.sym
        if self.bit == "e4m3":
            self.e_bits = 4
            self.m_bits = 3
            self.fp_dtype = torch.float8_e4m3fn
        elif self.bit == "e5m2":
            self.e_bits = 5
            self.m_bits = 2
            self.fp_dtype = torch.float8_e5m2
        else:
            raise ValueError(f"Unsupported bit configuration: {self.bit}")
        finfo = torch.finfo(self.fp_dtype)
        self.qmin, self.qmax = finfo.min, finfo.max

@@ -141,13 +148,9 @@ class FloatQuantizer(BaseQuantizer):
    def quant(self, tensor, scales, zeros, qmax, qmin):
        scaled_tensor = tensor / scales + zeros
        scaled_tensor = torch.clip(scaled_tensor, self.qmin.cuda(), self.qmax.cuda())
        org_dtype = scaled_tensor.dtype
        q_tensor = float_quantize(scaled_tensor.float(), self.e_bits, self.m_bits, rounding="nearest")
        q_tensor.to(org_dtype)
        return q_tensor

@@ -161,9 +164,9 @@ class FloatQuantizer(BaseQuantizer):
        return tensor

if __name__ == "__main__":
    weight = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda()
    quantizer = IntegerQuantizer(4, False, "per_group", group_size=128)
    q_weight = quantizer.fake_quant_tensor(weight)
    print(weight)
    print(q_weight)

@@ -175,7 +178,7 @@ if __name__ == '__main__':
    print(f"zeros = {zeros}, {zeros.shape}")
    weight = torch.randn(8192, 4096, dtype=torch.bfloat16).cuda()
    quantizer = FloatQuantizer("e4m3", True, "per_channel")
    q_weight = quantizer.fake_quant_tensor(weight)
    print(weight)
    print(q_weight)
...
@@ -11,13 +11,13 @@ class Register(dict):
    def register(self, target, key=None):
        if not callable(target):
            raise Exception(f"Error: {target} must be callable!")
        if key is None:
            key = target.__name__
        if key in self._dict:
            raise Exception(f"{key} already exists.")
        self[key] = target
        return target
...
@@ -66,12 +66,7 @@ def cache_video(
        # preprocess
        tensor = tensor.clamp(min(value_range), max(value_range))
        tensor = torch.stack(
            [torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range) for u in tensor.unbind(2)],
            dim=1,
        ).permute(1, 2, 3, 0)
        tensor = (tensor * 255).type(torch.uint8).cpu()
...
[tool.ruff]
exclude = [".git", ".mypy_cache", ".ruff_cache", ".venv", "dist"]
target-version = "py311"
line-length = 200
indent-width = 4
lint.ignore = ["F"]
[tool.ruff.format]
line-ending = "lf"
quote-style = "double"
indent-style = "space"
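With this configuration (assuming the block lives in pyproject.toml, which is where [tool.ruff] tables are normally read from), ruff lints with a 200-character line limit, ignores the entire F rule family (matching the "first lint without rule F" note in the commit message), and formats code to double quotes, 4-space indents, and LF line endings. Running "ruff check ." and "ruff format ." from the repository root would pick these settings up automatically, which is what the reformatted hunks above reflect.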
File mode changed from 100644 to 100755