Commit a50bcc53 authored by Dongz, committed by Yang Yong (雍洋)

add lint feature and minor fix (#7)

* [minor]: optimize Dockerfile for fewer layers

* [feature]: add pre-commit lint, update readme for contribution guidance

* [minor]: fix run shell privileges

* [auto]: first lint without rule F, fix rule E

* [minor]: fix Dockerfile error
parent 3b460075
@@ -39,28 +39,15 @@ def init_weights(m):
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)


class GELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


class T5LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim

@@ -75,7 +62,6 @@ class T5LayerNorm(nn.Module):
class T5Attention(nn.Module):
    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()

@@ -128,7 +114,6 @@ class T5Attention(nn.Module):
class T5FeedForward(nn.Module):
    def __init__(self, dim, dim_ffn, dropout=0.1):
        super(T5FeedForward, self).__init__()
        self.dim = dim

@@ -149,7 +134,6 @@ class T5FeedForward(nn.Module):
class T5SelfAttention(nn.Module):
    def __init__(
        self,
        dim,

@@ -173,11 +157,7 @@ class T5SelfAttention(nn.Module):
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))

@@ -187,7 +167,6 @@ class T5SelfAttention(nn.Module):
class T5CrossAttention(nn.Module):
    def __init__(
        self,
        dim,

@@ -213,27 +192,17 @@ class T5CrossAttention(nn.Module):
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)

    def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x


class T5RelativeEmbedding(nn.Module):
    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets

@@ -248,9 +217,7 @@ class T5RelativeEmbedding(nn.Module):
        device = self.embedding.weight.device
        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
        #     torch.arange(lq).unsqueeze(1).to(device)
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0)  # [1, N, Lq, Lk]

@@ -269,23 +236,13 @@ class T5RelativeEmbedding(nn.Module):
        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets
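A side note on the bucketing logic in the hunk above: offsets smaller than `max_exact` keep their own bucket, while larger offsets are compressed into logarithmically spaced buckets up to `max_dist` and clamped to the last bucket. A small standalone sketch, reproducing only the visible "large position" branch with illustrative `num_buckets`/`max_dist` values:

# --- editor's example, not part of the commit ---
import math
import torch

num_buckets, max_dist = 32, 128          # illustrative values, not the model's config
max_exact = num_buckets // 2
rel_pos = torch.tensor([1, 8, 15, 16, 32, 64, 127, 500])
# same expression as the hunk above, with self.max_dist replaced by the local constant
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(max_dist / max_exact) * (num_buckets - max_exact)).long()
rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
buckets = torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
print(buckets)  # tensor([ 1,  8, 15, 16, 21, 26, 31, 31])
# --- end of editor's example ---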
class T5Encoder(nn.Module):
    def __init__(
        self,
        vocab,

@@ -308,23 +265,10 @@ class T5Encoder(nn.Module):
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
        self.norm = T5LayerNorm(dim)

        # initialize weights

@@ -342,7 +286,6 @@ class T5Encoder(nn.Module):
class T5Decoder(nn.Module):
    def __init__(
        self,
        vocab,

@@ -365,23 +308,10 @@ class T5Decoder(nn.Module):
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
        self.norm = T5LayerNorm(dim)

        # initialize weights

@@ -408,7 +338,6 @@ class T5Decoder(nn.Module):
class T5Model(nn.Module):
    def __init__(
        self,
        vocab_size,

@@ -530,7 +459,6 @@ def umt5_xxl(**kwargs):
class T5EncoderModel:
    def __init__(
        self,
        text_len,

@@ -547,13 +475,7 @@ class T5EncoderModel:
        self.tokenizer_path = tokenizer_path

        # init model
        model = umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device).eval().requires_grad_(False)
        logging.info(f"loading {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
        self.model = model

@@ -562,9 +484,7 @@ class T5EncoderModel:
        else:
            self.model.to(self.device)
        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")

    def to_cpu(self):
        self.model = self.model.to("cpu")
...
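The `GELU` module reformatted in the hunks above is the usual tanh approximation of GELU. As a quick standalone sanity check (not part of the commit), it tracks PyTorch's built-in erf-based GELU closely and matches the built-in tanh variant exactly:

# --- editor's example, not part of the commit ---
import math
import torch
import torch.nn.functional as F

def gelu_tanh(x):
    # same expression as the GELU module above
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

x = torch.linspace(-4.0, 4.0, steps=9)
print(torch.allclose(gelu_tanh(x), F.gelu(x), atol=1e-3))                      # True: close to the exact erf form
print(torch.allclose(gelu_tanh(x), F.gelu(x, approximate="tanh"), atol=1e-6))  # True: same formula
# --- end of editor's example ---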
@@ -24,10 +24,7 @@ def whitespace_clean(text):
def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()

@@ -36,7 +33,6 @@ def canonicalize(text, keep_punctuation_exact_string=None):
class HuggingfaceTokenizer:
    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, "whitespace", "lower", "canonicalize")
        self.name = name
...
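For context, the `canonicalize` helper touched above replaces underscores with spaces, strips punctuation (optionally preserving one exact placeholder string), and lowercases. A standalone copy of the visible part, with a trailing `return` added so it runs in isolation (the real function continues past this hunk):

# --- editor's example, not part of the commit ---
import string

def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans("", "", string.punctuation))
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    return text.lower()  # the original keeps going; whitespace cleanup is outside this hunk

print(canonicalize("A_photo, of a Cat!"))                                      # a photo of a cat
print(canonicalize("{prompt}, 4k", keep_punctuation_exact_string="{prompt}"))  # {prompt} 4k
# --- end of editor's example ---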
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config

try:
    # This diffusers is modified and packed in the mirror.
    from diffusers.loaders import FromOriginalVAEMixin
except ImportError:
    # Use this to be compatible with the original diffusers.
    from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D


@dataclass
class DecoderOutput2(BaseOutput):
    sample: torch.FloatTensor
    posterior: Optional[DiagonalGaussianDistribution] = None


class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
    r"""
    A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
        up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        sample_tsize: int = 64,
        scaling_factor: float = 0.18215,
        force_upcast: float = True,
        spatial_compression_ratio: int = 8,
        time_compression_ratio: int = 4,
        mid_block_add_attention: bool = True,
    ):
        super().__init__()

        self.time_compression_ratio = time_compression_ratio

        self.encoder = EncoderCausal3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            mid_block_add_attention=mid_block_add_attention,
        )

        self.decoder = DecoderCausal3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            mid_block_add_attention=mid_block_add_attention,
        )

        self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
        self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

        self.use_slicing = False
        self.use_spatial_tiling = False
        self.use_temporal_tiling = False

        # only relevant if vae tiling is enabled
        self.tile_sample_min_tsize = sample_tsize
        self.tile_latent_min_tsize = sample_tsize // time_compression_ratio

        self.tile_sample_min_size = self.config.sample_size
        sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
            module.gradient_checkpointing = value

    def enable_temporal_tiling(self, use_tiling: bool = True):
        self.use_temporal_tiling = use_tiling

    def disable_temporal_tiling(self):
        self.enable_temporal_tiling(False)

    def enable_spatial_tiling(self, use_tiling: bool = True):
        self.use_spatial_tiling = use_tiling

    def disable_spatial_tiling(self):
        self.enable_spatial_tiling(False)

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger videos.
        """
        self.enable_spatial_tiling(use_tiling)
        self.enable_temporal_tiling(use_tiling)

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.disable_spatial_tiling()
        self.disable_temporal_tiling()

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}")

        self.set_attn_processor(processor, _remove_lora=True)

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images/videos into latents.

        Args:
            x (`torch.FloatTensor`): Input batch of images/videos.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images/videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        assert len(x.shape) == 5, "The input tensor should have 5 dimensions."

        if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
            return self.temporal_tiled_encode(x, return_dict=return_dict)

        if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.spatial_tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        assert len(z.shape) == 5, "The input tensor should have 5 dimensions."

        if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
            return self.temporal_tiled_decode(z, return_dict=return_dict)

        if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.spatial_tiled_decode(z, return_dict=return_dict)

        z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of images/videos.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.

        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
        for x in range(blend_extent):
            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
        return b

    def spatial_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False) -> AutoencoderKLOutput:
        r"""Encode a batch of images/videos using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.FloatTensor`): Input batch of images/videos.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split video into tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[-2], overlap_size):
            row = []
            for j in range(0, x.shape[-1], overlap_size):
                tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))

        moments = torch.cat(result_rows, dim=-2)
        if return_moments:
            return moments

        posterior = DiagonalGaussianDistribution(moments)
        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Decode a batch of images/videos using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[-2], overlap_size):
            row = []
            for j in range(0, z.shape[-1], overlap_size):
                tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))

        dec = torch.cat(result_rows, dim=-2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        B, C, T, H, W = x.shape
        overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
        t_limit = self.tile_latent_min_tsize - blend_extent

        # Split the video into tiles and encode them separately.
        row = []
        for i in range(0, T, overlap_size):
            tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
            if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
                tile = self.spatial_tiled_encode(tile, return_moments=True)
            else:
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
            if i > 0:
                tile = tile[:, :, 1:, :, :]
            row.append(tile)
        result_row = []
        for i, tile in enumerate(row):
            if i > 0:
                tile = self.blend_t(row[i - 1], tile, blend_extent)
                result_row.append(tile[:, :, :t_limit, :, :])
            else:
                result_row.append(tile[:, :, : t_limit + 1, :, :])

        moments = torch.cat(result_row, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        # Split z into overlapping tiles and decode them separately.

        B, C, T, H, W = z.shape
        overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
        t_limit = self.tile_sample_min_tsize - blend_extent

        row = []
        for i in range(0, T, overlap_size):
            tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
            if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
                decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
            else:
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
            if i > 0:
                decoded = decoded[:, :, 1:, :, :]
            row.append(decoded)
        result_row = []
        for i, tile in enumerate(row):
            if i > 0:
                tile = self.blend_t(row[i - 1], tile, blend_extent)
                result_row.append(tile[:, :, :t_limit, :, :])
            else:
                result_row.append(tile[:, :, : t_limit + 1, :, :])

        dec = torch.cat(result_row, dim=2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        return_posterior: bool = False,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput2, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            if return_posterior:
                return (dec, posterior)
            else:
                return (dec,)
        if return_posterior:
            return DecoderOutput2(sample=dec, posterior=posterior)
        else:
            return DecoderOutput2(sample=dec)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
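The tiled encode/decode paths above stitch overlapping tiles back together with the `blend_v`/`blend_h`/`blend_t` helpers, which linearly cross-fade the overlap region. A minimal 1-D sketch of the same ramp, using hypothetical tensors that are not from the repository:

# --- editor's example, not part of the commit ---
import torch

def blend_1d(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Over `blend_extent` positions, the tail of tile `a` fades out (1 -> 0)
    # while the head of tile `b` fades in (0 -> 1), exactly like blend_h above
    # but restricted to a single axis.
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    b = b.clone()  # the real helpers modify `b` in place; clone here to keep the demo side-effect free
    for x in range(blend_extent):
        b[..., x] = a[..., -blend_extent + x] * (1 - x / blend_extent) + b[..., x] * (x / blend_extent)
    return b

a = torch.ones(8)          # tail of the previous tile
b = torch.zeros(8)         # head of the next tile
print(blend_1d(a, b, 4))   # ramps 1.00, 0.75, 0.50, 0.25, then zeros
# --- end of editor's example ---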
@@ -3,7 +3,7 @@ import torch
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D


class VideoEncoderKLCausal3DModel:
    def __init__(self, model_path, dtype, device):
        self.model_path = model_path
        self.dtype = dtype
@@ -11,10 +11,10 @@ class VideoEncoderKLCausal3DModel():
        self.load()

    def load(self):
        self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
        config = AutoencoderKLCausal3D.load_config(self.vae_path)
        self.model = AutoencoderKLCausal3D.from_config(config)
        ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(dtype=self.dtype, device=self.device)
        self.model.requires_grad_(False)
@@ -32,14 +32,13 @@ class VideoEncoderKLCausal3DModel():
        latents = latents / self.model.config.scaling_factor
        latents = latents.to(dtype=self.dtype, device=torch.device("cuda"))
        self.model.enable_tiling()
        image = self.model.decode(latents, return_dict=False, generator=generator)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().float()
        if args.cpu_offload:
            self.to_cpu()
        return image


if __name__ == "__main__":
    vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme0/yongyang/projects/hy/new/HunyuanVideo/ckpts", dtype=torch.float16, device=torch.device("cuda"))
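
# --- Illustrative post-processing sketch (not part of the original diff) ---
# The decode path above maps the reconstruction from [-1, 1] to [0, 1] via
# (image / 2 + 0.5).clamp(0, 1). A generic follow-up step (assumption: callers want
# uint8 video frames in THWC layout) could look like this; `image` stands for the
# (B, C, T, H, W) float tensor produced by the decode path, and the helper name is
# hypothetical.
def _to_uint8_frames(image: torch.Tensor):
    frames = (image[0] * 255).round().clamp(0, 255).to(torch.uint8)  # C, T, H, W
    return frames.permute(1, 2, 3, 0).cpu().numpy()  # T, H, W, C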
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

from diffusers.utils import logging
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.attention_processor import Attention
from diffusers.models.normalization import AdaGroupNorm
from diffusers.models.normalization import RMSNorm

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def prepare_causal_attention_mask_ori(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
    seq_len = n_frame * n_hw
    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
    for i in range(seq_len):
        i_frame = i // n_hw
        mask[i, : (i_frame + 1) * n_hw] = 0
    if batch_size is not None:
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask


def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
    seq_len = n_frame * n_hw
    mask = torch.full((n_frame, n_frame, n_hw, n_hw), float("-inf"), dtype=dtype, device=device)
    # mask = mask.reshape(n_frame, n_frame, n_hw, n_hw)
    idx_arr = torch.tril(torch.ones(n_frame, n_frame, dtype=dtype, device=device))
    idx_arr = idx_arr > torch.zeros_like(idx_arr)
    for i in range(n_frame):
        for j in range(n_frame):
            if idx_arr[i, j]:
                mask[i, j] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
    # mask[idx_arr] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
    mask = mask.view(n_frame, -1, n_hw).transpose(1, 0).reshape(seq_len, -1).transpose(1, 0)
    if batch_size is not None:
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask.to(device)
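
# --- Illustrative sanity check (not part of the original diff) ---
# The vectorised helper above is meant to build the same block-causal mask as the _ori
# reference, just without the Python loop over every one of the seq_len rows. A quick
# self-contained check on toy sizes (the sizes are arbitrary):
def _check_causal_mask_equivalence():
    ori = prepare_causal_attention_mask_ori(4, 6, torch.float32, "cpu", batch_size=2)
    fast = prepare_causal_attention_mask(4, 6, torch.float32, "cpu", batch_size=2)
    assert ori.shape == fast.shape == (2, 24, 24)
    # 0 where a query may attend (its own and earlier frames), -inf elsewhere
    assert torch.equal(ori, fast)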

class CausalConv3d(nn.Module):
    """
    Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
    This maintains temporal causality in video generation tasks.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        pad_mode="replicate",
        **kwargs,
    ):
        super().__init__()

        self.pad_mode = pad_mode
        padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0)  # W, H, T
        self.time_causal_padding = padding

        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)
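
# --- Illustrative causality check (not part of the original diff) ---
# Because the time axis is padded with (kernel_size - 1, 0), output frame t only sees
# input frames <= t. A small self-contained check (all sizes are arbitrary):
def _check_causal_conv3d():
    torch.manual_seed(0)
    conv = CausalConv3d(3, 8, kernel_size=3)
    x = torch.randn(1, 3, 9, 16, 16)
    x_perturbed = x.clone()
    x_perturbed[:, :, 5:] += 1.0  # change only the "future" frames 5..8
    y, y_perturbed = conv(x), conv(x_perturbed)
    # Frames 0..4 of the output are unaffected by the perturbation
    assert torch.allclose(y[:, :, :5], y_perturbed[:, :, :5])
    assert y.shape == (1, 8, 9, 16, 16)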

class UpsampleCausal3D(nn.Module):
    """
    A 3D upsampling layer with an optional convolution.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = None,
        padding=1,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
        interpolate=True,
        upsample_factor=(2, 2, 2),
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.interpolate = interpolate
        self.upsample_factor = upsample_factor

        if norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps, elementwise_affine)
        elif norm_type is None:
            self.norm = None
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        conv = None
        if use_conv_transpose:
            raise NotImplementedError
        elif use_conv:
            if kernel_size is None:
                kernel_size = 3
            conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)

        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_size: Optional[int] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            raise NotImplementedError

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if self.interpolate:
            B, C, T, H, W = hidden_states.shape
            first_h, other_h = hidden_states.split((1, T - 1), dim=2)
            if output_size is None:
                if T > 1:
                    other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")

                # first_h = first_h.squeeze(2)
                first_h = first_h.view(B, C, H, W)
                first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
                first_h = first_h.unsqueeze(2)
            else:
                raise NotImplementedError

            if T > 1:
                hidden_states = torch.cat((first_h, other_h), dim=2)
            else:
                hidden_states = first_h

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
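
# --- Illustrative shape example (not part of the original diff) ---
# With upsample_factor=(2, 2, 2) the first frame is only upsampled spatially while the
# remaining T - 1 frames are upsampled in time and space, so T frames become
# 1 + 2 * (T - 1). Sketch with arbitrary sizes:
def _upsample_shape_example():
    up = UpsampleCausal3D(64, use_conv=True, out_channels=64, upsample_factor=(2, 2, 2))
    x = torch.randn(1, 64, 5, 16, 16)
    assert up(x).shape == (1, 64, 9, 32, 32)  # 5 frames -> 1 + 2 * 4 = 9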

class DownsampleCausal3D(nn.Module):
    """
    A 3D downsampling layer with an optional convolution.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
        kernel_size=3,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
        stride=2,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = stride
        self.name = name

        if norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps, elementwise_affine)
        elif norm_type is None:
            self.norm = None
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        if use_conv:
            conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
        else:
            raise NotImplementedError

        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        assert hidden_states.shape[1] == self.channels

        hidden_states = self.conv(hidden_states)

        return hidden_states
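
# --- Illustrative shape example (not part of the original diff) ---
# A stride-(2, 2, 2) causal downsample halves H/W and maps 2k + 1 frames to k + 1 frames,
# the inverse of the upsampling rule above. Sketch with arbitrary sizes:
def _downsample_shape_example():
    down = DownsampleCausal3D(64, use_conv=True, out_channels=64, name="op", stride=(2, 2, 2))
    x = torch.randn(1, 64, 9, 32, 32)
    assert down(x).shape == (1, 64, 5, 16, 16)  # 9 frames -> (9 - 1) // 2 + 1 = 5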
r"""
A Resnet block.
class ResnetBlockCausal3D(nn.Module): """
r"""
A Resnet block. def __init__(
""" self,
*,
def __init__( in_channels: int,
self, out_channels: Optional[int] = None,
*, conv_shortcut: bool = False,
in_channels: int, dropout: float = 0.0,
out_channels: Optional[int] = None, temb_channels: int = 512,
conv_shortcut: bool = False, groups: int = 32,
dropout: float = 0.0, groups_out: Optional[int] = None,
temb_channels: int = 512, pre_norm: bool = True,
groups: int = 32, eps: float = 1e-6,
groups_out: Optional[int] = None, non_linearity: str = "swish",
pre_norm: bool = True, skip_time_act: bool = False,
eps: float = 1e-6, # default, scale_shift, ada_group, spatial
non_linearity: str = "swish", time_embedding_norm: str = "default",
skip_time_act: bool = False, kernel: Optional[torch.FloatTensor] = None,
# default, scale_shift, ada_group, spatial output_scale_factor: float = 1.0,
time_embedding_norm: str = "default", use_in_shortcut: Optional[bool] = None,
kernel: Optional[torch.FloatTensor] = None, up: bool = False,
output_scale_factor: float = 1.0, down: bool = False,
use_in_shortcut: Optional[bool] = None, conv_shortcut_bias: bool = True,
up: bool = False, conv_3d_out_channels: Optional[int] = None,
down: bool = False, ):
conv_shortcut_bias: bool = True, super().__init__()
conv_3d_out_channels: Optional[int] = None, self.pre_norm = pre_norm
): self.pre_norm = True
super().__init__() self.in_channels = in_channels
self.pre_norm = pre_norm out_channels = in_channels if out_channels is None else out_channels
self.pre_norm = True self.out_channels = out_channels
self.in_channels = in_channels self.use_conv_shortcut = conv_shortcut
out_channels = in_channels if out_channels is None else out_channels self.up = up
self.out_channels = out_channels self.down = down
self.use_conv_shortcut = conv_shortcut self.output_scale_factor = output_scale_factor
self.up = up self.time_embedding_norm = time_embedding_norm
self.down = down self.skip_time_act = skip_time_act
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm linear_cls = nn.Linear
self.skip_time_act = skip_time_act
if groups_out is None:
linear_cls = nn.Linear groups_out = groups
if groups_out is None: if self.time_embedding_norm == "ada_group":
groups_out = groups self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
elif self.time_embedding_norm == "spatial":
if self.time_embedding_norm == "ada_group": self.norm1 = SpatialNorm(in_channels, temb_channels)
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) else:
elif self.time_embedding_norm == "spatial": self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.norm1 = SpatialNorm(in_channels, temb_channels)
else: self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
if temb_channels is not None:
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1) if self.time_embedding_norm == "default":
self.time_emb_proj = linear_cls(temb_channels, out_channels)
if temb_channels is not None: elif self.time_embedding_norm == "scale_shift":
if self.time_embedding_norm == "default": self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
self.time_emb_proj = linear_cls(temb_channels, out_channels) elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
elif self.time_embedding_norm == "scale_shift": self.time_emb_proj = None
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels) else:
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
self.time_emb_proj = None else:
else: self.time_emb_proj = None
raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
else: if self.time_embedding_norm == "ada_group":
self.time_emb_proj = None self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
elif self.time_embedding_norm == "spatial":
if self.time_embedding_norm == "ada_group": self.norm2 = SpatialNorm(out_channels, temb_channels)
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) else:
elif self.time_embedding_norm == "spatial": self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.norm2 = SpatialNorm(out_channels, temb_channels)
else: self.dropout = torch.nn.Dropout(dropout)
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) conv_3d_out_channels = conv_3d_out_channels or out_channels
self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
self.dropout = torch.nn.Dropout(dropout)
conv_3d_out_channels = conv_3d_out_channels or out_channels self.nonlinearity = get_activation(non_linearity)
self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
self.upsample = self.downsample = None
self.nonlinearity = get_activation(non_linearity) if self.up:
self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
self.upsample = self.downsample = None elif self.down:
if self.up: self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
elif self.down: self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
self.conv_shortcut = None
self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut if self.use_in_shortcut:
self.conv_shortcut = CausalConv3d(
self.conv_shortcut = None in_channels,
if self.use_in_shortcut: conv_3d_out_channels,
self.conv_shortcut = CausalConv3d( kernel_size=1,
in_channels, stride=1,
conv_3d_out_channels, bias=conv_shortcut_bias,
kernel_size=1, )
stride=1,
bias=conv_shortcut_bias, def forward(
) self,
input_tensor: torch.FloatTensor,
def forward( temb: torch.FloatTensor,
self, scale: float = 1.0,
input_tensor: torch.FloatTensor, ) -> torch.FloatTensor:
temb: torch.FloatTensor, hidden_states = input_tensor
scale: float = 1.0,
) -> torch.FloatTensor: if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = input_tensor hidden_states = self.norm1(hidden_states, temb)
else:
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm1(hidden_states)
hidden_states = self.norm1(hidden_states, temb)
else: hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.norm1(hidden_states)
if self.upsample is not None:
hidden_states = self.nonlinearity(hidden_states) # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
if self.upsample is not None: input_tensor = input_tensor.contiguous()
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 hidden_states = hidden_states.contiguous()
if hidden_states.shape[0] >= 64: input_tensor = self.upsample(input_tensor, scale=scale)
input_tensor = input_tensor.contiguous() hidden_states = self.upsample(hidden_states, scale=scale)
hidden_states = hidden_states.contiguous() elif self.downsample is not None:
input_tensor = ( input_tensor = self.downsample(input_tensor, scale=scale)
self.upsample(input_tensor, scale=scale) hidden_states = self.downsample(hidden_states, scale=scale)
)
hidden_states = ( hidden_states = self.conv1(hidden_states)
self.upsample(hidden_states, scale=scale)
) if self.time_emb_proj is not None:
elif self.downsample is not None: if not self.skip_time_act:
input_tensor = ( temb = self.nonlinearity(temb)
self.downsample(input_tensor, scale=scale) temb = self.time_emb_proj(temb, scale)[:, :, None, None]
)
hidden_states = ( if temb is not None and self.time_embedding_norm == "default":
self.downsample(hidden_states, scale=scale) hidden_states = hidden_states + temb
)
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.conv1(hidden_states) hidden_states = self.norm2(hidden_states, temb)
else:
if self.time_emb_proj is not None: hidden_states = self.norm2(hidden_states)
if not self.skip_time_act:
temb = self.nonlinearity(temb) if temb is not None and self.time_embedding_norm == "scale_shift":
temb = ( scale, shift = torch.chunk(temb, 2, dim=1)
self.time_emb_proj(temb, scale)[:, :, None, None] hidden_states = hidden_states * (1 + scale) + shift
)
hidden_states = self.nonlinearity(hidden_states)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm2(hidden_states, temb) if self.conv_shortcut is not None:
else: input_tensor = self.conv_shortcut(input_tensor)
hidden_states = self.norm2(hidden_states)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1) return output_tensor
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states) def get_down_block3d(
down_block_type: str,
hidden_states = self.dropout(hidden_states) num_layers: int,
hidden_states = self.conv2(hidden_states) in_channels: int,
out_channels: int,
if self.conv_shortcut is not None: temb_channels: int,
input_tensor = ( add_downsample: bool,
self.conv_shortcut(input_tensor) downsample_stride: int,
) resnet_eps: float,
resnet_act_fn: str,
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor transformer_layers_per_block: int = 1,
num_attention_heads: Optional[int] = None,
return output_tensor resnet_groups: Optional[int] = None,
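
# --- Illustrative shape example (not part of the original diff) ---
# When out_channels differs from in_channels a 1x1x1 causal conv shortcut is added so
# the residual sum still matches; without a time embedding, temb can simply be None.
# Sketch with arbitrary sizes:
def _resnet_block_example():
    block = ResnetBlockCausal3D(in_channels=64, out_channels=128, temb_channels=None)
    x = torch.randn(1, 64, 5, 16, 16)
    assert block(x, temb=None).shape == (1, 128, 5, 16, 16)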

def get_down_block3d(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    add_downsample: bool,
    downsample_stride: int,
    resnet_eps: float,
    resnet_act_fn: str,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    downsample_padding: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    downsample_type: Optional[str] = None,
    dropout: float = 0.0,
):
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
        attention_head_dim = num_attention_heads

    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownEncoderBlockCausal3D":
        return DownEncoderBlockCausal3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            downsample_stride=downsample_stride,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{down_block_type} does not exist.")

def get_up_block3d(
    up_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    prev_output_channel: int,
    temb_channels: int,
    add_upsample: bool,
    upsample_scale_factor: Tuple,
    resnet_eps: float,
    resnet_act_fn: str,
    resolution_idx: Optional[int] = None,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    upsample_type: Optional[str] = None,
    dropout: float = 0.0,
) -> nn.Module:
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
        attention_head_dim = num_attention_heads

    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpDecoderBlockCausal3D":
        return UpDecoderBlockCausal3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            resolution_idx=resolution_idx,
            dropout=dropout,
            add_upsample=add_upsample,
            upsample_scale_factor=upsample_scale_factor,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
        )
    raise ValueError(f"{up_block_type} does not exist.")

class UNetMidBlockCausal3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attn_groups: Optional[int] = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        if attn_groups is None:
            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None

        # there is always at least one resnet
        resnets = [
            ResnetBlockCausal3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        if attention_head_dim is None:
            logger.warn(f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}.")
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=attn_groups,
                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                attentions.append(None)

            resnets.append(
                ResnetBlockCausal3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                B, C, T, H, W = hidden_states.shape
                hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
                attention_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
                hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
                hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
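
# --- Illustrative mid-block sketch (not part of the original diff) ---
# Attention in the mid block runs over the flattened (frame, height, width) token axis,
# with prepare_causal_attention_mask stopping tokens from attending to later frames.
# Sketch with arbitrary sizes:
def _mid_block_example():
    mid = UNetMidBlockCausal3D(in_channels=256, temb_channels=None, attention_head_dim=256)
    x = torch.randn(1, 256, 3, 8, 8)
    mask = prepare_causal_attention_mask(3, 8 * 8, x.dtype, x.device, batch_size=1)
    assert mask.shape == (1, 192, 192)  # (B, T*H*W, T*H*W)
    assert mid(x).shape == x.shape      # the mid block preserves the spatio-temporal shape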

class DownEncoderBlockCausal3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_stride: int = 2,
        downsample_padding: int = 1,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlockCausal3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    DownsampleCausal3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                        stride=downsample_stride,
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=None, scale=scale)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states, scale)

        return hidden_states

class UpDecoderBlockCausal3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        upsample_scale_factor=(2, 2, 2),
        temb_channels: Optional[int] = None,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            resnets.append(
                ResnetBlockCausal3D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    UpsampleCausal3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        upsample_factor=upsample_scale_factor,
                    )
                ]
            )
        else:
            self.upsamplers = None

        self.resolution_idx = resolution_idx

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=temb, scale=scale)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
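
# --- Illustrative round trip (not part of the original diff) ---
# A single down block followed by a matching up block restores the original (T, H, W)
# extent, which is what lets the decoder mirror the encoder. Sketch with arbitrary sizes:
def _down_up_round_trip_example():
    down = DownEncoderBlockCausal3D(in_channels=64, out_channels=64, add_downsample=True, downsample_stride=(2, 2, 2))
    up = UpDecoderBlockCausal3D(in_channels=64, out_channels=64, add_upsample=True, upsample_scale_factor=(2, 2, 2))
    x = torch.randn(1, 64, 9, 32, 32)
    z = down(x)                    # (1, 64, 5, 16, 16)
    assert up(z).shape == x.shape  # (1, 64, 9, 32, 32)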
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from .unet_causal_3d_blocks import (
    CausalConv3d,
    UNetMidBlockCausal3D,
    get_down_block3d,
    get_up_block3d,
)


@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: torch.FloatTensor

class EncoderCausal3D(nn.Module):
    r"""
    The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_downsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                add_spatial_downsample = bool(i < num_spatial_downsample_layers)
                add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2,) if add_time_downsample else (1,)
            downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
            down_block = get_down_block3d(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=bool(add_spatial_downsample or add_time_downsample),
                downsample_stride=downsample_stride,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `EncoderCausal3D` class."""
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"

        sample = self.conv_in(sample)

        # down
        for down_block in self.down_blocks:
            sample = down_block(sample)

        # middle
        sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
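
# --- Illustrative encoder configuration (not part of the original diff) ---
# Sketch of a HunyuanVideo-style configuration (the channel widths below are
# assumptions, not read from the shipped checkpoint): 4x temporal and 8x spatial
# compression, with double_z=True so conv_out emits mean and logvar channels.
def _encoder_shape_example():
    enc = EncoderCausal3D(
        in_channels=3,
        out_channels=16,
        down_block_types=("DownEncoderBlockCausal3D",) * 4,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        time_compression_ratio=4,
        spatial_compression_ratio=8,
    )
    x = torch.randn(1, 3, 17, 256, 256)
    assert enc(x).shape == (1, 32, 5, 32, 32)  # 17 -> (17 - 1) / 4 + 1 frames, 256 -> 32 px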

class DecoderCausal3D(nn.Module):
    r"""
    The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_upsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                add_spatial_upsample = bool(i < num_spatial_upsample_layers)
                add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
            upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
            upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
            up_block = get_up_block3d(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=bool(add_spatial_upsample or add_time_upsample),
                upsample_scale_factor=upsample_scale_factor,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `DecoderCausal3D` class."""
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."

        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    latent_embeds,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
# post-process return sample
if latent_embeds is None:
sample = self.conv_norm_out(sample)
else: class DiagonalGaussianDistribution(object):
sample = self.conv_norm_out(sample, latent_embeds) def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
sample = self.conv_act(sample) if parameters.ndim == 3:
sample = self.conv_out(sample) dim = 2 # (B, L, C)
elif parameters.ndim == 5 or parameters.ndim == 4:
return sample dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
else:
raise NotImplementedError
class DiagonalGaussianDistribution(object): self.parameters = parameters
def __init__(self, parameters: torch.Tensor, deterministic: bool = False): self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
if parameters.ndim == 3: self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
dim = 2 # (B, L, C) self.deterministic = deterministic
elif parameters.ndim == 5 or parameters.ndim == 4: self.std = torch.exp(0.5 * self.logvar)
dim = 1 # (B, C, T, H ,W) / (B, C, H, W) self.var = torch.exp(self.logvar)
else: if self.deterministic:
raise NotImplementedError self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim) def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
self.logvar = torch.clamp(self.logvar, -30.0, 20.0) # make sure sample is on the same device as the parameters and has same dtype
self.deterministic = deterministic sample = randn_tensor(
self.std = torch.exp(0.5 * self.logvar) self.mean.shape,
self.var = torch.exp(self.logvar) generator=generator,
if self.deterministic: device=self.parameters.device,
self.var = self.std = torch.zeros_like( dtype=self.parameters.dtype,
self.mean, device=self.parameters.device, dtype=self.parameters.dtype )
) x = self.mean + self.std * sample
return x
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
sample = randn_tensor( if self.deterministic:
self.mean.shape, return torch.Tensor([0.0])
generator=generator, else:
device=self.parameters.device, reduce_dim = list(range(1, self.mean.ndim))
dtype=self.parameters.dtype, if other is None:
) return 0.5 * torch.sum(
x = self.mean + self.std * sample torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
return x dim=reduce_dim,
)
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: else:
if self.deterministic: return 0.5 * torch.sum(
return torch.Tensor([0.0]) torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
else: dim=reduce_dim,
reduce_dim = list(range(1, self.mean.ndim)) )
if other is None:
return 0.5 * torch.sum( def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, if self.deterministic:
dim=reduce_dim, return torch.Tensor([0.0])
) logtwopi = np.log(2.0 * np.pi)
else: return 0.5 * torch.sum(
return 0.5 * torch.sum( logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
torch.pow(self.mean - other.mean, 2) / other.var dim=dims,
+ self.var / other.var )
- 1.0
- self.logvar def mode(self) -> torch.Tensor:
+ other.logvar, return self.mean
dim=reduce_dim,
)
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar +
torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean
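For orientation, a minimal usage sketch of the posterior class above; the tensor shapes and z_dim = 16 are illustrative assumptions (this diff does not fix them), and sample() relies on the randn_tensor helper imported elsewhere in this module.

# Minimal sketch (assumed shapes; run in the context of this module, since
# sample() uses the module's randn_tensor helper).
import torch

params = torch.randn(1, 32, 5, 16, 16)            # (B, 2 * z_dim, T, H, W), z_dim assumed to be 16
posterior = DiagonalGaussianDistribution(params)   # splits channels into mean / logvar halves
z = posterior.sample()                             # reparameterized draw, shape (1, 16, 5, 16, 16)
kl_to_prior = posterior.kl()                       # closed-form KL against a standard normal, shape (1,)
z_det = posterior.mode()                           # the mean, used for deterministic decoding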
...@@ -44,7 +44,6 @@ class CausalConv3d(nn.Conv3d): ...@@ -44,7 +44,6 @@ class CausalConv3d(nn.Conv3d):
class RMS_norm(nn.Module): class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False): def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__() super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1) broadcastable_dims = (1, 1, 1) if not images else (1, 1)
...@@ -56,16 +55,10 @@ class RMS_norm(nn.Module): ...@@ -56,16 +55,10 @@ class RMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x): def forward(self, x):
return ( return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
F.normalize(x, dim=(1 if self.channel_first else -1))
* self.scale
* self.gamma
+ self.bias
)
class Upsample(nn.Upsample): class Upsample(nn.Upsample):
def forward(self, x): def forward(self, x):
""" """
Fix bfloat16 support for nearest neighbor interpolation. Fix bfloat16 support for nearest neighbor interpolation.
...@@ -74,7 +67,6 @@ class Upsample(nn.Upsample): ...@@ -74,7 +67,6 @@ class Upsample(nn.Upsample):
class Resample(nn.Module): class Resample(nn.Module):
def __init__(self, dim, mode): def __init__(self, dim, mode):
assert mode in ( assert mode in (
"none", "none",
...@@ -101,16 +93,10 @@ class Resample(nn.Module): ...@@ -101,16 +93,10 @@ class Resample(nn.Module):
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d": elif mode == "downsample2d":
self.resample = nn.Sequential( self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
elif mode == "downsample3d": elif mode == "downsample3d":
self.resample = nn.Sequential( self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
)
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
else: else:
self.resample = nn.Identity() self.resample = nn.Identity()
...@@ -124,28 +110,17 @@ class Resample(nn.Module): ...@@ -124,28 +110,17 @@ class Resample(nn.Module):
feat_cache[idx] = "Rep" feat_cache[idx] = "Rep"
feat_idx[0] += 1 feat_idx[0] += 1
else: else:
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :].clone()
if ( if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] != "Rep"
):
# cache last frame of last two chunk # cache last frame of last two chunk
cache_x = torch.cat( cache_x = torch.cat(
[ [
feat_cache[idx][:, :, -1, :, :] feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
.unsqueeze(2)
.to(cache_x.device),
cache_x, cache_x,
], ],
dim=2, dim=2,
) )
if ( if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] == "Rep"
):
cache_x = torch.cat( cache_x = torch.cat(
[torch.zeros_like(cache_x).to(cache_x.device), cache_x], [torch.zeros_like(cache_x).to(cache_x.device), cache_x],
dim=2, dim=2,
...@@ -172,15 +147,12 @@ class Resample(nn.Module): ...@@ -172,15 +147,12 @@ class Resample(nn.Module):
feat_cache[idx] = x.clone() feat_cache[idx] = x.clone()
feat_idx[0] += 1 feat_idx[0] += 1
else: else:
cache_x = x[:, :, -1:, :, :].clone() cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk # # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv( x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
)
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
return x return x
...@@ -210,7 +182,6 @@ class Resample(nn.Module): ...@@ -210,7 +182,6 @@ class Resample(nn.Module):
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0): def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__() super().__init__()
self.in_dim = in_dim self.in_dim = in_dim
...@@ -226,9 +197,7 @@ class ResidualBlock(nn.Module): ...@@ -226,9 +197,7 @@ class ResidualBlock(nn.Module):
nn.Dropout(dropout), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1), CausalConv3d(out_dim, out_dim, 3, padding=1),
) )
self.shortcut = ( self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
)
def forward(self, x, feat_cache=None, feat_idx=[0]): def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x) h = self.shortcut(x)
...@@ -240,9 +209,7 @@ class ResidualBlock(nn.Module): ...@@ -240,9 +209,7 @@ class ResidualBlock(nn.Module):
# cache last frame of last two chunk # cache last frame of last two chunk
cache_x = torch.cat( cache_x = torch.cat(
[ [
feat_cache[idx][:, :, -1, :, :] feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
.unsqueeze(2)
.to(cache_x.device),
cache_x, cache_x,
], ],
dim=2, dim=2,
...@@ -278,13 +245,7 @@ class AttentionBlock(nn.Module): ...@@ -278,13 +245,7 @@ class AttentionBlock(nn.Module):
x = rearrange(x, "b c t h w -> (b t) c h w") x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm(x) x = self.norm(x)
# compute query, key, value # compute query, key, value
q, k, v = ( q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
self.to_qkv(x)
.reshape(b * t, 1, c * 3, -1)
.permute(0, 1, 3, 2)
.contiguous()
.chunk(3, dim=-1)
)
# apply attention # apply attention
x = F.scaled_dot_product_attention( x = F.scaled_dot_product_attention(
...@@ -301,7 +262,6 @@ class AttentionBlock(nn.Module): ...@@ -301,7 +262,6 @@ class AttentionBlock(nn.Module):
class Encoder3d(nn.Module): class Encoder3d(nn.Module):
def __init__( def __init__(
self, self,
dim=128, dim=128,
...@@ -400,9 +360,7 @@ class Encoder3d(nn.Module): ...@@ -400,9 +360,7 @@ class Encoder3d(nn.Module):
# cache last frame of last two chunk # cache last frame of last two chunk
cache_x = torch.cat( cache_x = torch.cat(
[ [
feat_cache[idx][:, :, -1, :, :] feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
.unsqueeze(2)
.to(cache_x.device),
cache_x, cache_x,
], ],
dim=2, dim=2,
...@@ -416,7 +374,6 @@ class Encoder3d(nn.Module): ...@@ -416,7 +374,6 @@ class Encoder3d(nn.Module):
class Decoder3d(nn.Module): class Decoder3d(nn.Module):
def __init__( def __init__(
self, self,
dim=128, dim=128,
...@@ -518,9 +475,7 @@ class Decoder3d(nn.Module): ...@@ -518,9 +475,7 @@ class Decoder3d(nn.Module):
# cache last frame of last two chunk # cache last frame of last two chunk
cache_x = torch.cat( cache_x = torch.cat(
[ [
feat_cache[idx][:, :, -1, :, :] feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
.unsqueeze(2)
.to(cache_x.device),
cache_x, cache_x,
], ],
dim=2, dim=2,
...@@ -542,7 +497,6 @@ def count_conv3d(model): ...@@ -542,7 +497,6 @@ def count_conv3d(model):
class WanVAE_(nn.Module): class WanVAE_(nn.Module):
def __init__( def __init__(
self, self,
dim=128, dim=128,
...@@ -613,9 +567,7 @@ class WanVAE_(nn.Module): ...@@ -613,9 +567,7 @@ class WanVAE_(nn.Module):
out = torch.cat([out, out_], 2) out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1) mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor): if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
1, self.z_dim, 1, 1, 1
)
else: else:
mu = (mu - scale[0]) * scale[1] mu = (mu - scale[0]) * scale[1]
self.clear_cache() self.clear_cache()
...@@ -625,9 +577,7 @@ class WanVAE_(nn.Module): ...@@ -625,9 +577,7 @@ class WanVAE_(nn.Module):
self.clear_cache() self.clear_cache()
# z: [b,c,t,h,w] # z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor): if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
1, self.z_dim, 1, 1, 1
)
else: else:
z = z / scale[1] + scale[0] z = z / scale[1] + scale[0]
iter_ = z.shape[2] iter_ = z.shape[2]
...@@ -700,7 +650,6 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): ...@@ -700,7 +650,6 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
class WanVAE: class WanVAE:
def __init__( def __init__(
self, self,
z_dim=16, z_dim=16,
...@@ -780,11 +729,8 @@ class WanVAE: ...@@ -780,11 +729,8 @@ class WanVAE:
""" """
videos: A list of videos each with shape [C, T, H, W]. videos: A list of videos each with shape [C, T, H, W].
""" """
return [ return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
]
def decode_dist(self, zs, world_size, cur_rank, split_dim): def decode_dist(self, zs, world_size, cur_rank, split_dim):
splited_total_len = zs.shape[split_dim] splited_total_len = zs.shape[split_dim]
splited_chunk_len = splited_total_len // world_size splited_chunk_len = splited_total_len // world_size
...@@ -792,37 +738,37 @@ class WanVAE: ...@@ -792,37 +738,37 @@ class WanVAE:
if cur_rank == 0: if cur_rank == 0:
if split_dim == 2: if split_dim == 2:
zs = zs[:,:,:splited_chunk_len+2*padding_size,:].contiguous() zs = zs[:, :, : splited_chunk_len + 2 * padding_size, :].contiguous()
elif split_dim == 3: elif split_dim == 3:
zs = zs[:,:,:,:splited_chunk_len+2*padding_size].contiguous() zs = zs[:, :, :, : splited_chunk_len + 2 * padding_size].contiguous()
elif cur_rank == world_size-1: elif cur_rank == world_size - 1:
if split_dim == 2: if split_dim == 2:
zs = zs[:,:,-(splited_chunk_len+2*padding_size):,:].contiguous() zs = zs[:, :, -(splited_chunk_len + 2 * padding_size) :, :].contiguous()
elif split_dim == 3: elif split_dim == 3:
zs = zs[:,:,:,-(splited_chunk_len+2*padding_size):].contiguous() zs = zs[:, :, :, -(splited_chunk_len + 2 * padding_size) :].contiguous()
else: else:
if split_dim == 2: if split_dim == 2:
zs = zs[:,:,cur_rank*splited_chunk_len-padding_size:(cur_rank+1)*splited_chunk_len+padding_size,:].contiguous() zs = zs[:, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size, :].contiguous()
elif split_dim == 3: elif split_dim == 3:
zs = zs[:,:,:,cur_rank*splited_chunk_len-padding_size:(cur_rank+1)*splited_chunk_len+padding_size].contiguous() zs = zs[:, :, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size].contiguous()
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1) images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
if cur_rank == 0: if cur_rank == 0:
if split_dim == 2: if split_dim == 2:
images = images[:,:,:,:splited_chunk_len*8,:].contiguous() images = images[:, :, :, : splited_chunk_len * 8, :].contiguous()
elif split_dim == 3: elif split_dim == 3:
images = images[:,:,:,:,:splited_chunk_len*8].contiguous() images = images[:, :, :, :, : splited_chunk_len * 8].contiguous()
elif cur_rank == world_size-1: elif cur_rank == world_size - 1:
if split_dim == 2: if split_dim == 2:
images = images[:,:,:,-splited_chunk_len*8:,:].contiguous() images = images[:, :, :, -splited_chunk_len * 8 :, :].contiguous()
elif split_dim == 3: elif split_dim == 3:
images = images[:,:,:,:,-splited_chunk_len*8:].contiguous() images = images[:, :, :, :, -splited_chunk_len * 8 :].contiguous()
else: else:
if split_dim == 2: if split_dim == 2:
images = images[:,:,:,8*padding_size:-8*padding_size,:].contiguous() images = images[:, :, :, 8 * padding_size : -8 * padding_size, :].contiguous()
elif split_dim == 3: elif split_dim == 3:
images = images[:,:,:,:,8*padding_size:-8*padding_size].contiguous() images = images[:, :, :, :, 8 * padding_size : -8 * padding_size].contiguous()
full_images = [torch.empty_like(images) for _ in range(world_size)] full_images = [torch.empty_like(images) for _ in range(world_size)]
dist.all_gather(full_images, images) dist.all_gather(full_images, images)
...@@ -832,7 +778,6 @@ class WanVAE: ...@@ -832,7 +778,6 @@ class WanVAE:
images = torch.cat(full_images, dim=-1) images = torch.cat(full_images, dim=-1)
return images return images
def decode(self, zs, generator, args): def decode(self, zs, generator, args):
if args.cpu_offload: if args.cpu_offload:
......
...@@ -5,7 +5,7 @@ from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.autoencod ...@@ -5,7 +5,7 @@ from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.autoencod
from lightx2v.text2v.models.video_encoders.trt.autoencoder_kl_causal_3d import trt_vae_infer from lightx2v.text2v.models.video_encoders.trt.autoencoder_kl_causal_3d import trt_vae_infer
class VideoEncoderKLCausal3DModel(): class VideoEncoderKLCausal3DModel:
def __init__(self, model_path, dtype, device): def __init__(self, model_path, dtype, device):
self.model_path = model_path self.model_path = model_path
self.dtype = dtype self.dtype = dtype
...@@ -13,10 +13,10 @@ class VideoEncoderKLCausal3DModel(): ...@@ -13,10 +13,10 @@ class VideoEncoderKLCausal3DModel():
self.load() self.load()
def load(self): def load(self):
self.vae_path = os.path.join(self.model_path, 'hunyuan-video-t2v-720p/vae') self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
config = AutoencoderKLCausal3D.load_config(self.vae_path) config = AutoencoderKLCausal3D.load_config(self.vae_path)
self.model = AutoencoderKLCausal3D.from_config(config) self.model = AutoencoderKLCausal3D.from_config(config)
ckpt = torch.load(os.path.join(self.vae_path, 'pytorch_model.pt'), map_location='cpu', weights_only=True) ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
self.model.load_state_dict(ckpt) self.model.load_state_dict(ckpt)
self.model = self.model.to(dtype=self.dtype, device=self.device) self.model = self.model.to(dtype=self.dtype, device=self.device)
self.model.requires_grad_(False) self.model.requires_grad_(False)
...@@ -28,12 +28,11 @@ class VideoEncoderKLCausal3DModel(): ...@@ -28,12 +28,11 @@ class VideoEncoderKLCausal3DModel():
latents = latents / self.model.config.scaling_factor latents = latents / self.model.config.scaling_factor
latents = latents.to(dtype=self.dtype, device=self.device) latents = latents.to(dtype=self.dtype, device=self.device)
self.model.enable_tiling() self.model.enable_tiling()
image = self.model.decode( image = self.model.decode(latents, return_dict=False, generator=generator)[0]
latents, return_dict=False, generator=generator
)[0]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().float() image = image.cpu().float()
return image return image
if __name__ == "__main__": if __name__ == "__main__":
vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme1/yongyang/models/hy/ckpts", dtype=torch.float16, device=torch.device("cuda")) vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme1/yongyang/models/hy/ckpts", dtype=torch.float16, device=torch.device("cuda"))
...@@ -100,18 +100,18 @@ class HyVaeTrtModelInfer(nn.Module): ...@@ -100,18 +100,18 @@ class HyVaeTrtModelInfer(nn.Module):
device = batch.device device = batch.device
dtype = batch.dtype dtype = batch.dtype
batch = batch.cpu().numpy() batch = batch.cpu().numpy()
def get_output_shape(shp): def get_output_shape(shp):
b, c, t, h, w = shp b, c, t, h, w = shp
out = (b, 3, 4*(t-1)+1, h*8, w*8) out = (b, 3, 4 * (t - 1) + 1, h * 8, w * 8)
return out return out
shp_dict = {"inp": batch.shape, "out": get_output_shape(batch.shape)} shp_dict = {"inp": batch.shape, "out": get_output_shape(batch.shape)}
self.alloc(shp_dict) self.alloc(shp_dict)
output = np.zeros(*self.output_spec()) output = np.zeros(*self.output_spec())
# Process I/O and execute the network # Process I/O and execute the network
common.memcpy_host_to_device( common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(batch))
self.inputs[0]["allocation"], np.ascontiguousarray(batch)
)
self.context.execute_v2(self.allocations) self.context.execute_v2(self.allocations)
common.memcpy_device_to_host(output, self.outputs[0]["allocation"]) common.memcpy_device_to_host(output, self.outputs[0]["allocation"])
output = torch.from_numpy(output).to(device).type(dtype) output = torch.from_numpy(output).to(device).type(dtype)
...@@ -122,19 +122,16 @@ class HyVaeTrtModelInfer(nn.Module): ...@@ -122,19 +122,16 @@ class HyVaeTrtModelInfer(nn.Module):
logger.info("Start to do VAE onnx exporting.") logger.info("Start to do VAE onnx exporting.")
device = next(decoder.parameters())[0].device device = next(decoder.parameters())[0].device
example_inp = torch.rand(1, 16, 17, 32, 32).to(device).type(next(decoder.parameters())[0].dtype) example_inp = torch.rand(1, 16, 17, 32, 32).to(device).type(next(decoder.parameters())[0].dtype)
out_path = str(Path(str(model_dir))/"vae_decoder.onnx") out_path = str(Path(str(model_dir)) / "vae_decoder.onnx")
torch.onnx.export( torch.onnx.export(
decoder.eval().half(), decoder.eval().half(),
example_inp.half(), example_inp.half(),
out_path, out_path,
input_names=['inp'], input_names=["inp"],
output_names=['out'], output_names=["out"],
opset_version=14, opset_version=14,
dynamic_axes={ dynamic_axes={"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}, "out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}},
"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}, )
"out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}
}
)
# onnx_ori = onnx.load(out_path) # onnx_ori = onnx.load(out_path)
os.system(f"onnxsim {out_path} {out_path}") os.system(f"onnxsim {out_path} {out_path}")
# onnx_opt, check = simplify(onnx_ori) # onnx_opt, check = simplify(onnx_ori)
...@@ -163,4 +160,4 @@ class HyVaeTrtModelInfer(nn.Module): ...@@ -163,4 +160,4 @@ class HyVaeTrtModelInfer(nn.Module):
if not Path(engine_path).exists(): if not Path(engine_path).exists():
raise RuntimeError(f"Convert vae onnx({onnx_path}) to tensorrt engine failed.") raise RuntimeError(f"Convert vae onnx({onnx_path}) to tensorrt engine failed.")
logger.info("Finish VAE tensorrt converting.") logger.info("Finish VAE tensorrt converting.")
return engine_path return engine_path
\ No newline at end of file
...@@ -8,20 +8,20 @@ class BaseQuantizer(object): ...@@ -8,20 +8,20 @@ class BaseQuantizer(object):
self.sym = symmetric self.sym = symmetric
self.granularity = granularity self.granularity = granularity
self.kwargs = kwargs self.kwargs = kwargs
if self.granularity == 'per_group': if self.granularity == "per_group":
self.group_size = self.kwargs['group_size'] self.group_size = self.kwargs["group_size"]
self.calib_algo = self.kwargs.get('calib_algo', 'minmax') self.calib_algo = self.kwargs.get("calib_algo", "minmax")
def get_tensor_range(self, tensor): def get_tensor_range(self, tensor):
if self.calib_algo == 'minmax': if self.calib_algo == "minmax":
return self.get_minmax_range(tensor) return self.get_minmax_range(tensor)
elif self.calib_algo == 'mse': elif self.calib_algo == "mse":
return self.get_mse_range(tensor) return self.get_mse_range(tensor)
else: else:
raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}') raise ValueError(f"Unsupported calibration algorithm: {self.calib_algo}")
def get_minmax_range(self, tensor): def get_minmax_range(self, tensor):
if self.granularity == 'per_tensor': if self.granularity == "per_tensor":
max_val = torch.max(tensor) max_val = torch.max(tensor)
min_val = torch.min(tensor) min_val = torch.min(tensor)
else: else:
...@@ -47,7 +47,7 @@ class BaseQuantizer(object): ...@@ -47,7 +47,7 @@ class BaseQuantizer(object):
return scales, zeros, qmax, qmin return scales, zeros, qmax, qmin
def reshape_tensor(self, tensor, allow_padding=False): def reshape_tensor(self, tensor, allow_padding=False):
if self.granularity == 'per_group': if self.granularity == "per_group":
t = tensor.reshape(-1, self.group_size) t = tensor.reshape(-1, self.group_size)
else: else:
t = tensor t = tensor
...@@ -79,7 +79,7 @@ class BaseQuantizer(object): ...@@ -79,7 +79,7 @@ class BaseQuantizer(object):
tensor, scales, zeros, qmax, qmin = self.get_tensor_qparams(tensor) tensor, scales, zeros, qmax, qmin = self.get_tensor_qparams(tensor)
tensor = self.quant(tensor, scales, zeros, qmax, qmin) tensor = self.quant(tensor, scales, zeros, qmax, qmin)
tensor = self.restore_tensor(tensor, org_shape) tensor = self.restore_tensor(tensor, org_shape)
if self.sym == True: if self.sym:
zeros = None zeros = None
return tensor, scales, zeros return tensor, scales, zeros
...@@ -87,9 +87,9 @@ class BaseQuantizer(object): ...@@ -87,9 +87,9 @@ class BaseQuantizer(object):
class IntegerQuantizer(BaseQuantizer): class IntegerQuantizer(BaseQuantizer):
def __init__(self, bit, symmetric, granularity, **kwargs): def __init__(self, bit, symmetric, granularity, **kwargs):
super().__init__(bit, symmetric, granularity, **kwargs) super().__init__(bit, symmetric, granularity, **kwargs)
if 'int_range' in self.kwargs: if "int_range" in self.kwargs:
self.qmin = self.kwargs['int_range'][0] self.qmin = self.kwargs["int_range"][0]
self.qmax = self.kwargs['int_range'][1] self.qmax = self.kwargs["int_range"][1]
else: else:
if self.sym: if self.sym:
self.qmin = -(2 ** (self.bit - 1)) self.qmin = -(2 ** (self.bit - 1))
...@@ -110,7 +110,14 @@ class IntegerQuantizer(BaseQuantizer): ...@@ -110,7 +110,14 @@ class IntegerQuantizer(BaseQuantizer):
tensor = (tensor - zeros) * scales tensor = (tensor - zeros) * scales
return tensor return tensor
def quant_dequant(self, tensor, scales, zeros, qmax, qmin,): def quant_dequant(
self,
tensor,
scales,
zeros,
qmax,
qmin,
):
tensor = self.quant(tensor, scales, zeros, qmax, qmin) tensor = self.quant(tensor, scales, zeros, qmax, qmin)
tensor = self.dequant(tensor, scales, zeros) tensor = self.dequant(tensor, scales, zeros)
return tensor return tensor
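quant_dequant() above composes the quant and dequant primitives into a fake-quant round trip; here is a hedged, self-contained sketch of that round trip for symmetric per-channel int8, using generic min-max formulas (the repo's exact qparam derivation lives in code elided from this hunk).

# Hedged standalone sketch, not the repo's get_tensor_qparams logic.
import torch

w = torch.randn(4, 8)                              # pretend weight with 4 output channels
qmax, qmin = 127, -128                             # int8 symmetric range
scales = w.abs().amax(dim=1, keepdim=True) / qmax  # one scale per output channel
zeros = torch.zeros_like(scales)                   # symmetric => zero-point is 0
q = torch.clamp(torch.round(w / scales + zeros), qmin, qmax)  # "quant"
w_hat = (q - zeros) * scales                                  # "dequant", matches the formula above
print((w - w_hat).abs().max())                     # round-trip error, roughly bounded by scale / 2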
...@@ -119,19 +126,19 @@ class IntegerQuantizer(BaseQuantizer): ...@@ -119,19 +126,19 @@ class IntegerQuantizer(BaseQuantizer):
class FloatQuantizer(BaseQuantizer): class FloatQuantizer(BaseQuantizer):
def __init__(self, bit, symmetric, granularity, **kwargs): def __init__(self, bit, symmetric, granularity, **kwargs):
super().__init__(bit, symmetric, granularity, **kwargs) super().__init__(bit, symmetric, granularity, **kwargs)
assert self.bit in ['e4m3', 'e5m2'], f'Unsupported bit configuration: {self.bit}' assert self.bit in ["e4m3", "e5m2"], f"Unsupported bit configuration: {self.bit}"
assert self.sym == True assert self.sym
if self.bit == 'e4m3': if self.bit == "e4m3":
self.e_bits = 4 self.e_bits = 4
self.m_bits = 3 self.m_bits = 3
self.fp_dtype = torch.float8_e4m3fn self.fp_dtype = torch.float8_e4m3fn
elif self.bit == 'e5m2': elif self.bit == "e5m2":
self.e_bits = 5 self.e_bits = 5
self.m_bits = 2 self.m_bits = 2
self.fp_dtype = torch.float8_e5m2 self.fp_dtype = torch.float8_e5m2
else: else:
raise ValueError(f'Unsupported bit configuration: {self.bit}') raise ValueError(f"Unsupported bit configuration: {self.bit}")
finfo = torch.finfo(self.fp_dtype) finfo = torch.finfo(self.fp_dtype)
self.qmin, self.qmax = finfo.min, finfo.max self.qmin, self.qmax = finfo.min, finfo.max
...@@ -141,13 +148,9 @@ class FloatQuantizer(BaseQuantizer): ...@@ -141,13 +148,9 @@ class FloatQuantizer(BaseQuantizer):
def quant(self, tensor, scales, zeros, qmax, qmin): def quant(self, tensor, scales, zeros, qmax, qmin):
scaled_tensor = tensor / scales + zeros scaled_tensor = tensor / scales + zeros
scaled_tensor = torch.clip( scaled_tensor = torch.clip(scaled_tensor, self.qmin.cuda(), self.qmax.cuda())
scaled_tensor, self.qmin.cuda(), self.qmax.cuda()
)
org_dtype = scaled_tensor.dtype org_dtype = scaled_tensor.dtype
q_tensor = float_quantize( q_tensor = float_quantize(scaled_tensor.float(), self.e_bits, self.m_bits, rounding="nearest")
scaled_tensor.float(), self.e_bits, self.m_bits, rounding='nearest'
)
q_tensor = q_tensor.to(org_dtype) q_tensor = q_tensor.to(org_dtype)
return q_tensor return q_tensor
...@@ -161,21 +164,21 @@ class FloatQuantizer(BaseQuantizer): ...@@ -161,21 +164,21 @@ class FloatQuantizer(BaseQuantizer):
return tensor return tensor
if __name__ == '__main__': if __name__ == "__main__":
weight = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda() weight = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda()
quantizer = IntegerQuantizer(4, False, 'per_group', group_size=128) quantizer = IntegerQuantizer(4, False, "per_group", group_size=128)
q_weight = quantizer.fake_quant_tensor(weight) q_weight = quantizer.fake_quant_tensor(weight)
print(weight) print(weight)
print(q_weight) print(q_weight)
print(f"cosine = {torch.cosine_similarity(weight.view(1, -1).to(torch.float64), q_weight.view(1, -1).to(torch.float64))}") print(f"cosine = {torch.cosine_similarity(weight.view(1, -1).to(torch.float64), q_weight.view(1, -1).to(torch.float64))}")
realq_weight, scales, zeros = quantizer.real_quant_tensor(weight) realq_weight, scales, zeros = quantizer.real_quant_tensor(weight)
print(f"realq_weight = {realq_weight}, {realq_weight.shape}") print(f"realq_weight = {realq_weight}, {realq_weight.shape}")
print(f"scales = {scales}, {scales.shape}") print(f"scales = {scales}, {scales.shape}")
print(f"zeros = {zeros}, {zeros.shape}") print(f"zeros = {zeros}, {zeros.shape}")
weight = torch.randn(8192, 4096, dtype=torch.bfloat16).cuda() weight = torch.randn(8192, 4096, dtype=torch.bfloat16).cuda()
quantizer = FloatQuantizer('e4m3', True, 'per_channel') quantizer = FloatQuantizer("e4m3", True, "per_channel")
q_weight = quantizer.fake_quant_tensor(weight) q_weight = quantizer.fake_quant_tensor(weight)
print(weight) print(weight)
print(q_weight) print(q_weight)
......
...@@ -11,14 +11,14 @@ class Register(dict): ...@@ -11,14 +11,14 @@ class Register(dict):
def register(self, target, key=None): def register(self, target, key=None):
if not callable(target): if not callable(target):
raise Exception(f'Error: {target} must be callable!') raise Exception(f"Error: {target} must be callable!")
if key is None: if key is None:
key = target.__name__ key = target.__name__
if key in self._dict: if key in self._dict:
raise Exception(f'{key} already exists.') raise Exception(f"{key} already exists.")
self[key] = target self[key] = target
return target return target
......
...@@ -66,12 +66,7 @@ def cache_video( ...@@ -66,12 +66,7 @@ def cache_video(
# preprocess # preprocess
tensor = tensor.clamp(min(value_range), max(value_range)) tensor = tensor.clamp(min(value_range), max(value_range))
tensor = torch.stack( tensor = torch.stack(
[ [torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range) for u in tensor.unbind(2)],
torchvision.utils.make_grid(
u, nrow=nrow, normalize=normalize, value_range=value_range
)
for u in tensor.unbind(2)
],
dim=1, dim=1,
).permute(1, 2, 3, 0) ).permute(1, 2, 3, 0)
tensor = (tensor * 255).type(torch.uint8).cpu() tensor = (tensor * 255).type(torch.uint8).cpu()
......
[tool.ruff]
exclude = [".git", ".mypy_cache", ".ruff_cache", ".venv", "dist"]
target-version = "py311"
line-length = 200
indent-width = 4
lint.ignore = ["F"]
[tool.ruff.format]
line-ending = "lf"
quote-style = "double"
indent-style = "space"
...@@ -18,4 +18,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \ ...@@ -18,4 +18,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
--target_width 1280 \ --target_width 1280 \
--attention_type flash_attn3 \ --attention_type flash_attn3 \
--save_video_path ./output_lightx2v_int8.mp4 \ --save_video_path ./output_lightx2v_int8.mp4 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
...@@ -18,4 +18,4 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \ ...@@ -18,4 +18,4 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--target_width 1280 \ --target_width 1280 \
--attention_type flash_attn2 \ --attention_type flash_attn2 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \ --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
--parallel_attn --parallel_attn
\ No newline at end of file
...@@ -20,4 +20,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \ ...@@ -20,4 +20,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
--cpu_offload \ --cpu_offload \
--feature_caching TaylorSeer \ --feature_caching TaylorSeer \
--save_video_path ./output_lightx2v_offload_TaylorSeer.mp4 \ --save_video_path ./output_lightx2v_offload_TaylorSeer.mp4 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' # --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
\ No newline at end of file
...@@ -28,4 +28,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \ ...@@ -28,4 +28,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ --image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--mm_config '{"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \ --mm_config '{"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
# --feature_caching Tea \ # --feature_caching Tea \
# --use_ret_steps \ # --use_ret_steps \
\ No newline at end of file
...@@ -28,4 +28,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \ ...@@ -28,4 +28,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F", "weight_auto_quant": true}' \ # --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F", "weight_auto_quant": true}' \
# --feature_caching Tea \ # --feature_caching Tea \
# --use_ret_steps \ # --use_ret_steps \
# --teacache_thresh 0.2 # --teacache_thresh 0.2
\ No newline at end of file
File mode changed from 100644 to 100755