Commit a50bcc53 authored by Dongz's avatar Dongz Committed by Yang Yong(雍洋)
Browse files

add lint feature and minor fix (#7)

* [minor]: optimize dockerfile for fewer layer

* [feature]: add pre-commit lint, update readme for contribution guidance

* [minor]: fix run shell privileges

* [auto]: first lint without rule F, fix rule E

* [minor]: fix docker file error
parent 3b460075
......@@ -39,28 +39,15 @@ def init_weights(m):
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
elif isinstance(m, T5RelativeEmbedding):
nn.init.normal_(
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5
)
nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
class GELU(nn.Module):
def forward(self, x):
return (
0.5
* x
* (
1.0
+ torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
)
)
)
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super(T5LayerNorm, self).__init__()
self.dim = dim
......@@ -75,7 +62,6 @@ class T5LayerNorm(nn.Module):
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
......@@ -128,7 +114,6 @@ class T5Attention(nn.Module):
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1):
super(T5FeedForward, self).__init__()
self.dim = dim
......@@ -149,7 +134,6 @@ class T5FeedForward(nn.Module):
class T5SelfAttention(nn.Module):
def __init__(
self,
dim,
......@@ -173,11 +157,7 @@ class T5SelfAttention(nn.Module):
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = (
None
if shared_pos
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
......@@ -187,7 +167,6 @@ class T5SelfAttention(nn.Module):
class T5CrossAttention(nn.Module):
def __init__(
self,
dim,
......@@ -213,27 +192,17 @@ class T5CrossAttention(nn.Module):
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = (
None
if shared_pos
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
def forward(
self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None
):
def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(
x
+ self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)
)
x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super(T5RelativeEmbedding, self).__init__()
self.num_buckets = num_buckets
......@@ -248,9 +217,7 @@ class T5RelativeEmbedding(nn.Module):
device = self.embedding.weight.device
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
# torch.arange(lq).unsqueeze(1).to(device)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(
lq, device=device
).unsqueeze(1)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
rel_pos = self._relative_position_bucket(rel_pos)
rel_pos_embeds = self.embedding(rel_pos)
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
......@@ -269,23 +236,13 @@ class T5RelativeEmbedding(nn.Module):
# embeddings for small and large positions
max_exact = num_buckets // 2
rel_pos_large = (
max_exact
+ (
torch.log(rel_pos.float() / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).long()
)
rel_pos_large = torch.min(
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)
)
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long()
rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
return rel_buckets
class T5Encoder(nn.Module):
def __init__(
self,
vocab,
......@@ -308,23 +265,10 @@ class T5Encoder(nn.Module):
self.shared_pos = shared_pos
# layers
self.token_embedding = (
vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
if shared_pos
else None
)
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList(
[
T5SelfAttention(
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
)
for _ in range(num_layers)
]
)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim)
# initialize weights
......@@ -342,7 +286,6 @@ class T5Encoder(nn.Module):
class T5Decoder(nn.Module):
def __init__(
self,
vocab,
......@@ -365,23 +308,10 @@ class T5Decoder(nn.Module):
self.shared_pos = shared_pos
# layers
self.token_embedding = (
vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
if shared_pos
else None
)
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList(
[
T5CrossAttention(
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
)
for _ in range(num_layers)
]
)
self.blocks = nn.ModuleList([T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim)
# initialize weights
......@@ -408,7 +338,6 @@ class T5Decoder(nn.Module):
class T5Model(nn.Module):
def __init__(
self,
vocab_size,
......@@ -530,7 +459,6 @@ def umt5_xxl(**kwargs):
class T5EncoderModel:
def __init__(
self,
text_len,
......@@ -547,13 +475,7 @@ class T5EncoderModel:
self.tokenizer_path = tokenizer_path
# init model
model = (
umt5_xxl(
encoder_only=True, return_tokenizer=False, dtype=dtype, device=device
)
.eval()
.requires_grad_(False)
)
model = umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device).eval().requires_grad_(False)
logging.info(f"loading {checkpoint_path}")
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
self.model = model
......@@ -562,9 +484,7 @@ class T5EncoderModel:
else:
self.model.to(self.device)
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path, seq_len=text_len, clean="whitespace"
)
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
def to_cpu(self):
self.model = self.model.to("cpu")
......
......@@ -24,10 +24,7 @@ def whitespace_clean(text):
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace("_", " ")
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(str.maketrans("", "", string.punctuation))
for part in text.split(keep_punctuation_exact_string)
)
text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in text.split(keep_punctuation_exact_string))
else:
text = text.translate(str.maketrans("", "", string.punctuation))
text = text.lower()
......@@ -36,7 +33,6 @@ def canonicalize(text, keep_punctuation_exact_string=None):
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, "whitespace", "lower", "canonicalize")
self.name = name
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
try:
# This diffusers is modified and packed in the mirror.
from diffusers.loaders import FromOriginalVAEMixin
except ImportError:
# Use this to be compatible with the original diffusers.
from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
@dataclass
class DecoderOutput2(BaseOutput):
sample: torch.FloatTensor
posterior: Optional[DiagonalGaussianDistribution] = None
class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
sample_tsize: int = 64,
scaling_factor: float = 0.18215,
force_upcast: float = True,
spatial_compression_ratio: int = 8,
time_compression_ratio: int = 4,
mid_block_add_attention: bool = True,
):
super().__init__()
self.time_compression_ratio = time_compression_ratio
self.encoder = EncoderCausal3D(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
time_compression_ratio=time_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
mid_block_add_attention=mid_block_add_attention,
)
self.decoder = DecoderCausal3D(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
time_compression_ratio=time_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
mid_block_add_attention=mid_block_add_attention,
)
self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
self.use_slicing = False
self.use_spatial_tiling = False
self.use_temporal_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_tsize = sample_tsize
self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
module.gradient_checkpointing = value
def enable_temporal_tiling(self, use_tiling: bool = True):
self.use_temporal_tiling = use_tiling
def disable_temporal_tiling(self):
self.enable_temporal_tiling(False)
def enable_spatial_tiling(self, use_tiling: bool = True):
self.use_spatial_tiling = use_tiling
def disable_spatial_tiling(self):
self.enable_spatial_tiling(False)
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger videos.
"""
self.enable_spatial_tiling(use_tiling)
self.enable_temporal_tiling(use_tiling)
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.disable_spatial_tiling()
self.disable_temporal_tiling()
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images/videos into latents.
Args:
x (`torch.FloatTensor`): Input batch of images/videos.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images/videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
return self.temporal_tiled_encode(x, return_dict=return_dict)
if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.spatial_tiled_encode(x, return_dict=return_dict)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
return self.temporal_tiled_decode(z, return_dict=return_dict)
if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.spatial_tiled_decode(z, return_dict=return_dict)
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images/videos.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
return b
def spatial_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False) -> AutoencoderKLOutput:
r"""Encode a batch of images/videos using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images/videos.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split video into tiles and encode them separately.
rows = []
for i in range(0, x.shape[-2], overlap_size):
row = []
for j in range(0, x.shape[-1], overlap_size):
tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
if return_moments:
return moments
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Decode a batch of images/videos using a tiled decoder.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[-2], overlap_size):
row = []
for j in range(0, z.shape[-1], overlap_size):
tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
B, C, T, H, W = x.shape
overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
t_limit = self.tile_latent_min_tsize - blend_extent
# Split the video into tiles and encode them separately.
row = []
for i in range(0, T, overlap_size):
tile = x[:, :, i: i + self.tile_sample_min_tsize + 1, :, :]
if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
tile = self.spatial_tiled_encode(tile, return_moments=True)
else:
tile = self.encoder(tile)
tile = self.quant_conv(tile)
if i > 0:
tile = tile[:, :, 1:, :, :]
row.append(tile)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_extent)
result_row.append(tile[:, :, :t_limit, :, :])
else:
result_row.append(tile[:, :, :t_limit + 1, :, :])
moments = torch.cat(result_row, dim=2)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
# Split z into overlapping tiles and decode them separately.
B, C, T, H, W = z.shape
overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
t_limit = self.tile_sample_min_tsize - blend_extent
row = []
for i in range(0, T, overlap_size):
tile = z[:, :, i: i + self.tile_latent_min_tsize + 1, :, :]
if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
else:
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
if i > 0:
decoded = decoded[:, :, 1:, :, :]
row.append(decoded)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_extent)
result_row.append(tile[:, :, :t_limit, :, :])
else:
result_row.append(tile[:, :, :t_limit + 1, :, :])
dec = torch.cat(result_row, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
return_posterior: bool = False,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput2, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
if return_posterior:
return (dec, posterior)
else:
return (dec,)
if return_posterior:
return DecoderOutput2(sample=dec, posterior=posterior)
else:
return DecoderOutput2(sample=dec)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
try:
# This diffusers is modified and packed in the mirror.
from diffusers.loaders import FromOriginalVAEMixin
except ImportError:
# Use this to be compatible with the original diffusers.
from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
@dataclass
class DecoderOutput2(BaseOutput):
sample: torch.FloatTensor
posterior: Optional[DiagonalGaussianDistribution] = None
class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
sample_tsize: int = 64,
scaling_factor: float = 0.18215,
force_upcast: float = True,
spatial_compression_ratio: int = 8,
time_compression_ratio: int = 4,
mid_block_add_attention: bool = True,
):
super().__init__()
self.time_compression_ratio = time_compression_ratio
self.encoder = EncoderCausal3D(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
time_compression_ratio=time_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
mid_block_add_attention=mid_block_add_attention,
)
self.decoder = DecoderCausal3D(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
time_compression_ratio=time_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
mid_block_add_attention=mid_block_add_attention,
)
self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
self.use_slicing = False
self.use_spatial_tiling = False
self.use_temporal_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_tsize = sample_tsize
self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
self.tile_sample_min_size = self.config.sample_size
sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
module.gradient_checkpointing = value
def enable_temporal_tiling(self, use_tiling: bool = True):
self.use_temporal_tiling = use_tiling
def disable_temporal_tiling(self):
self.enable_temporal_tiling(False)
def enable_spatial_tiling(self, use_tiling: bool = True):
self.use_spatial_tiling = use_tiling
def disable_spatial_tiling(self):
self.enable_spatial_tiling(False)
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger videos.
"""
self.enable_spatial_tiling(use_tiling)
self.enable_temporal_tiling(use_tiling)
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.disable_spatial_tiling()
self.disable_temporal_tiling()
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}")
self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images/videos into latents.
Args:
x (`torch.FloatTensor`): Input batch of images/videos.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images/videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
return self.temporal_tiled_encode(x, return_dict=return_dict)
if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.spatial_tiled_encode(x, return_dict=return_dict)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
return self.temporal_tiled_decode(z, return_dict=return_dict)
if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.spatial_tiled_decode(z, return_dict=return_dict)
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images/videos.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
return b
def spatial_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False) -> AutoencoderKLOutput:
r"""Encode a batch of images/videos using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images/videos.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split video into tiles and encode them separately.
rows = []
for i in range(0, x.shape[-2], overlap_size):
row = []
for j in range(0, x.shape[-1], overlap_size):
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
if return_moments:
return moments
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Decode a batch of images/videos using a tiled decoder.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[-2], overlap_size):
row = []
for j in range(0, z.shape[-1], overlap_size):
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
B, C, T, H, W = x.shape
overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
t_limit = self.tile_latent_min_tsize - blend_extent
# Split the video into tiles and encode them separately.
row = []
for i in range(0, T, overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
tile = self.spatial_tiled_encode(tile, return_moments=True)
else:
tile = self.encoder(tile)
tile = self.quant_conv(tile)
if i > 0:
tile = tile[:, :, 1:, :, :]
row.append(tile)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_extent)
result_row.append(tile[:, :, :t_limit, :, :])
else:
result_row.append(tile[:, :, : t_limit + 1, :, :])
moments = torch.cat(result_row, dim=2)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
# Split z into overlapping tiles and decode them separately.
B, C, T, H, W = z.shape
overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
t_limit = self.tile_sample_min_tsize - blend_extent
row = []
for i in range(0, T, overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
else:
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
if i > 0:
decoded = decoded[:, :, 1:, :, :]
row.append(decoded)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_extent)
result_row.append(tile[:, :, :t_limit, :, :])
else:
result_row.append(tile[:, :, : t_limit + 1, :, :])
dec = torch.cat(result_row, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
return_posterior: bool = False,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput2, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
if return_posterior:
return (dec, posterior)
else:
return (dec,)
if return_posterior:
return DecoderOutput2(sample=dec, posterior=posterior)
else:
return DecoderOutput2(sample=dec)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
......@@ -3,7 +3,7 @@ import torch
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
class VideoEncoderKLCausal3DModel():
class VideoEncoderKLCausal3DModel:
def __init__(self, model_path, dtype, device):
self.model_path = model_path
self.dtype = dtype
......@@ -11,10 +11,10 @@ class VideoEncoderKLCausal3DModel():
self.load()
def load(self):
self.vae_path = os.path.join(self.model_path, 'hunyuan-video-t2v-720p/vae')
self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
config = AutoencoderKLCausal3D.load_config(self.vae_path)
self.model = AutoencoderKLCausal3D.from_config(config)
ckpt = torch.load(os.path.join(self.vae_path, 'pytorch_model.pt'), map_location='cpu', weights_only=True)
ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
self.model.load_state_dict(ckpt)
self.model = self.model.to(dtype=self.dtype, device=self.device)
self.model.requires_grad_(False)
......@@ -32,14 +32,13 @@ class VideoEncoderKLCausal3DModel():
latents = latents / self.model.config.scaling_factor
latents = latents.to(dtype=self.dtype, device=torch.device("cuda"))
self.model.enable_tiling()
image = self.model.decode(
latents, return_dict=False, generator=generator
)[0]
image = self.model.decode(latents, return_dict=False, generator=generator)[0]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().float()
if args.cpu_offload:
self.to_cpu()
return image
if __name__ == "__main__":
vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme0/yongyang/projects/hy/new/HunyuanVideo/ckpts", dtype=torch.float16, device=torch.device("cuda"))
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
from diffusers.utils import logging
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.attention_processor import Attention
from diffusers.models.normalization import AdaGroupNorm
from diffusers.models.normalization import RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def prepare_causal_attention_mask_ori(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
seq_len = n_frame * n_hw
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // n_hw
mask[i, : (i_frame + 1) * n_hw] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
seq_len = n_frame * n_hw
mask = torch.full((n_frame, n_frame, n_hw, n_hw), float("-inf"), dtype=dtype, device=device)
# mask = mask.reshape(n_frame, n_frame, n_hw, n_hw)
idx_arr = torch.tril(torch.ones(n_frame, n_frame, dtype=dtype, device=device))
idx_arr = idx_arr > torch.zeros_like(idx_arr)
for i in range(n_frame):
for j in range(n_frame):
if idx_arr[i,j]:
mask[i, j] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
# mask[idx_arr] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
mask = mask.view(n_frame, -1, n_hw).transpose(1, 0).reshape(seq_len, -1).transpose(1,0)
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask.to(device)
class CausalConv3d(nn.Module):
"""
Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
This maintains temporal causality in video generation tasks.
"""
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
dilation: Union[int, Tuple[int, int, int]] = 1,
pad_mode='replicate',
**kwargs
):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
self.time_causal_padding = padding
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class UpsampleCausal3D(nn.Module):
"""
A 3D upsampling layer with an optional convolution.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
use_conv_transpose: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
kernel_size: Optional[int] = None,
padding=1,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
interpolate=True,
upsample_factor=(2, 2, 2),
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.interpolate = interpolate
self.upsample_factor = upsample_factor
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps, elementwise_affine)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")
conv = None
if use_conv_transpose:
raise NotImplementedError
elif use_conv:
if kernel_size is None:
kernel_size = 3
conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv
def forward(
self,
hidden_states: torch.FloatTensor,
output_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
raise NotImplementedError
if self.use_conv_transpose:
return self.conv(hidden_states)
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if self.interpolate:
B, C, T, H, W = hidden_states.shape
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
if output_size is None:
if T > 1:
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
# first_h = first_h.squeeze(2)
first_h = first_h.view(B, C, H, W)
first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
first_h = first_h.unsqueeze(2)
else:
raise NotImplementedError
if T > 1:
hidden_states = torch.cat((first_h, other_h), dim=2)
else:
hidden_states = first_h
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
if self.use_conv:
if self.name == "conv":
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.Conv2d_0(hidden_states)
return hidden_states
class DownsampleCausal3D(nn.Module):
"""
A 3D downsampling layer with an optional convolution.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 1,
name: str = "conv",
kernel_size=3,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
stride=2,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = stride
self.name = name
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps, elementwise_affine)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")
if use_conv:
conv = CausalConv3d(
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
)
else:
raise NotImplementedError
if name == "conv":
self.Conv2d_0 = conv
self.conv = conv
elif name == "Conv2d_0":
self.conv = conv
else:
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlockCausal3D(nn.Module):
r"""
A Resnet block.
"""
def __init__(
self,
*,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
groups_out: Optional[int] = None,
pre_norm: bool = True,
eps: float = 1e-6,
non_linearity: str = "swish",
skip_time_act: bool = False,
# default, scale_shift, ada_group, spatial
time_embedding_norm: str = "default",
kernel: Optional[torch.FloatTensor] = None,
output_scale_factor: float = 1.0,
use_in_shortcut: Optional[bool] = None,
up: bool = False,
down: bool = False,
conv_shortcut_bias: bool = True,
conv_3d_out_channels: Optional[int] = None,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
linear_cls = nn.Linear
if groups_out is None:
groups_out = groups
if self.time_embedding_norm == "ada_group":
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm1 = SpatialNorm(in_channels, temb_channels)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = linear_cls(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
self.time_emb_proj = None
else:
raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
else:
self.time_emb_proj = None
if self.time_embedding_norm == "ada_group":
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm2 = SpatialNorm(out_channels, temb_channels)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
conv_3d_out_channels = conv_3d_out_channels or out_channels
self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
self.nonlinearity = get_activation(non_linearity)
self.upsample = self.downsample = None
if self.up:
self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
elif self.down:
self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = CausalConv3d(
in_channels,
conv_3d_out_channels,
kernel_size=1,
stride=1,
bias=conv_shortcut_bias,
)
def forward(
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor,
scale: float = 1.0,
) -> torch.FloatTensor:
hidden_states = input_tensor
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm1(hidden_states, temb)
else:
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = (
self.upsample(input_tensor, scale=scale)
)
hidden_states = (
self.upsample(hidden_states, scale=scale)
)
elif self.downsample is not None:
input_tensor = (
self.downsample(input_tensor, scale=scale)
)
hidden_states = (
self.downsample(hidden_states, scale=scale)
)
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = (
self.time_emb_proj(temb, scale)[:, :, None, None]
)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm2(hidden_states, temb)
else:
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = (
self.conv_shortcut(input_tensor)
)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
def get_down_block3d(
down_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
temb_channels: int,
add_downsample: bool,
downsample_stride: int,
resnet_eps: float,
resnet_act_fn: str,
transformer_layers_per_block: int = 1,
num_attention_heads: Optional[int] = None,
resnet_groups: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
downsample_padding: Optional[int] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
attention_type: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: float = 1.0,
cross_attention_norm: Optional[str] = None,
attention_head_dim: Optional[int] = None,
downsample_type: Optional[str] = None,
dropout: float = 0.0,
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warn(
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
attention_head_dim = num_attention_heads
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownEncoderBlockCausal3D":
return DownEncoderBlockCausal3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
add_downsample=add_downsample,
downsample_stride=downsample_stride,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block3d(
up_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
add_upsample: bool,
upsample_scale_factor: Tuple,
resnet_eps: float,
resnet_act_fn: str,
resolution_idx: Optional[int] = None,
transformer_layers_per_block: int = 1,
num_attention_heads: Optional[int] = None,
resnet_groups: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
attention_type: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: float = 1.0,
cross_attention_norm: Optional[str] = None,
attention_head_dim: Optional[int] = None,
upsample_type: Optional[str] = None,
dropout: float = 0.0,
) -> nn.Module:
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warn(
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
attention_head_dim = num_attention_heads
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpDecoderBlockCausal3D":
return UpDecoderBlockCausal3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
upsample_scale_factor=upsample_scale_factor,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
)
raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlockCausal3D(nn.Module):
"""
A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
"""
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
if attn_groups is None:
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
# there is always at least one resnet
resnets = [
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
if attention_head_dim is None:
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
)
attention_head_dim = in_channels
for _ in range(num_layers):
if self.add_attention:
attentions.append(
Attention(
in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=attn_groups,
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
else:
attentions.append(None)
resnets.append(
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
B, C, T, H, W = hidden_states.shape
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
attention_mask = prepare_causal_attention_mask(
T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
)
hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
hidden_states = resnet(hidden_states, temb)
return hidden_states
class DownEncoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_stride: int = 2,
downsample_padding: int = 1,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
DownsampleCausal3D(
out_channels,
use_conv=True,
out_channels=out_channels,
padding=downsample_padding,
name="op",
stride=downsample_stride,
)
]
)
else:
self.downsamplers = None
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None, scale=scale)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale)
return hidden_states
class UpDecoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
upsample_scale_factor=(2, 2, 2),
temb_channels: Optional[int] = None,
):
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockCausal3D(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList(
[
UpsampleCausal3D(
out_channels,
use_conv=True,
out_channels=out_channels,
upsample_factor=upsample_scale_factor,
)
]
)
else:
self.upsamplers = None
self.resolution_idx = resolution_idx
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
from diffusers.utils import logging
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.attention_processor import Attention
from diffusers.models.normalization import AdaGroupNorm
from diffusers.models.normalization import RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def prepare_causal_attention_mask_ori(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
seq_len = n_frame * n_hw
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // n_hw
mask[i, : (i_frame + 1) * n_hw] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
seq_len = n_frame * n_hw
mask = torch.full((n_frame, n_frame, n_hw, n_hw), float("-inf"), dtype=dtype, device=device)
# mask = mask.reshape(n_frame, n_frame, n_hw, n_hw)
idx_arr = torch.tril(torch.ones(n_frame, n_frame, dtype=dtype, device=device))
idx_arr = idx_arr > torch.zeros_like(idx_arr)
for i in range(n_frame):
for j in range(n_frame):
if idx_arr[i, j]:
mask[i, j] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
# mask[idx_arr] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
mask = mask.view(n_frame, -1, n_hw).transpose(1, 0).reshape(seq_len, -1).transpose(1, 0)
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask.to(device)
class CausalConv3d(nn.Module):
"""
Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
This maintains temporal causality in video generation tasks.
"""
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
dilation: Union[int, Tuple[int, int, int]] = 1,
pad_mode="replicate",
**kwargs,
):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
self.time_causal_padding = padding
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class UpsampleCausal3D(nn.Module):
"""
A 3D upsampling layer with an optional convolution.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
use_conv_transpose: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
kernel_size: Optional[int] = None,
padding=1,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
interpolate=True,
upsample_factor=(2, 2, 2),
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.interpolate = interpolate
self.upsample_factor = upsample_factor
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps, elementwise_affine)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")
conv = None
if use_conv_transpose:
raise NotImplementedError
elif use_conv:
if kernel_size is None:
kernel_size = 3
conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv
def forward(
self,
hidden_states: torch.FloatTensor,
output_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
raise NotImplementedError
if self.use_conv_transpose:
return self.conv(hidden_states)
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if self.interpolate:
B, C, T, H, W = hidden_states.shape
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
if output_size is None:
if T > 1:
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
# first_h = first_h.squeeze(2)
first_h = first_h.view(B, C, H, W)
first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
first_h = first_h.unsqueeze(2)
else:
raise NotImplementedError
if T > 1:
hidden_states = torch.cat((first_h, other_h), dim=2)
else:
hidden_states = first_h
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
if self.use_conv:
if self.name == "conv":
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.Conv2d_0(hidden_states)
return hidden_states
class DownsampleCausal3D(nn.Module):
"""
A 3D downsampling layer with an optional convolution.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 1,
name: str = "conv",
kernel_size=3,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
stride=2,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = stride
self.name = name
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps, elementwise_affine)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")
if use_conv:
conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
else:
raise NotImplementedError
if name == "conv":
self.Conv2d_0 = conv
self.conv = conv
elif name == "Conv2d_0":
self.conv = conv
else:
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlockCausal3D(nn.Module):
r"""
A Resnet block.
"""
def __init__(
self,
*,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
groups_out: Optional[int] = None,
pre_norm: bool = True,
eps: float = 1e-6,
non_linearity: str = "swish",
skip_time_act: bool = False,
# default, scale_shift, ada_group, spatial
time_embedding_norm: str = "default",
kernel: Optional[torch.FloatTensor] = None,
output_scale_factor: float = 1.0,
use_in_shortcut: Optional[bool] = None,
up: bool = False,
down: bool = False,
conv_shortcut_bias: bool = True,
conv_3d_out_channels: Optional[int] = None,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
linear_cls = nn.Linear
if groups_out is None:
groups_out = groups
if self.time_embedding_norm == "ada_group":
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm1 = SpatialNorm(in_channels, temb_channels)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = linear_cls(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
self.time_emb_proj = None
else:
raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
else:
self.time_emb_proj = None
if self.time_embedding_norm == "ada_group":
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm2 = SpatialNorm(out_channels, temb_channels)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
conv_3d_out_channels = conv_3d_out_channels or out_channels
self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
self.nonlinearity = get_activation(non_linearity)
self.upsample = self.downsample = None
if self.up:
self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
elif self.down:
self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = CausalConv3d(
in_channels,
conv_3d_out_channels,
kernel_size=1,
stride=1,
bias=conv_shortcut_bias,
)
def forward(
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor,
scale: float = 1.0,
) -> torch.FloatTensor:
hidden_states = input_tensor
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm1(hidden_states, temb)
else:
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = self.upsample(input_tensor, scale=scale)
hidden_states = self.upsample(hidden_states, scale=scale)
elif self.downsample is not None:
input_tensor = self.downsample(input_tensor, scale=scale)
hidden_states = self.downsample(hidden_states, scale=scale)
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb, scale)[:, :, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm2(hidden_states, temb)
else:
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
def get_down_block3d(
down_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
temb_channels: int,
add_downsample: bool,
downsample_stride: int,
resnet_eps: float,
resnet_act_fn: str,
transformer_layers_per_block: int = 1,
num_attention_heads: Optional[int] = None,
resnet_groups: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
downsample_padding: Optional[int] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
attention_type: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: float = 1.0,
cross_attention_norm: Optional[str] = None,
attention_head_dim: Optional[int] = None,
downsample_type: Optional[str] = None,
dropout: float = 0.0,
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
attention_head_dim = num_attention_heads
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownEncoderBlockCausal3D":
return DownEncoderBlockCausal3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
add_downsample=add_downsample,
downsample_stride=downsample_stride,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block3d(
up_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
add_upsample: bool,
upsample_scale_factor: Tuple,
resnet_eps: float,
resnet_act_fn: str,
resolution_idx: Optional[int] = None,
transformer_layers_per_block: int = 1,
num_attention_heads: Optional[int] = None,
resnet_groups: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
attention_type: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: float = 1.0,
cross_attention_norm: Optional[str] = None,
attention_head_dim: Optional[int] = None,
upsample_type: Optional[str] = None,
dropout: float = 0.0,
) -> nn.Module:
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
attention_head_dim = num_attention_heads
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpDecoderBlockCausal3D":
return UpDecoderBlockCausal3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
upsample_scale_factor=upsample_scale_factor,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
)
raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlockCausal3D(nn.Module):
"""
A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
"""
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
if attn_groups is None:
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
# there is always at least one resnet
resnets = [
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
if attention_head_dim is None:
logger.warn(f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}.")
attention_head_dim = in_channels
for _ in range(num_layers):
if self.add_attention:
attentions.append(
Attention(
in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=attn_groups,
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
else:
attentions.append(None)
resnets.append(
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
B, C, T, H, W = hidden_states.shape
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
attention_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
hidden_states = resnet(hidden_states, temb)
return hidden_states
class DownEncoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_stride: int = 2,
downsample_padding: int = 1,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
DownsampleCausal3D(
out_channels,
use_conv=True,
out_channels=out_channels,
padding=downsample_padding,
name="op",
stride=downsample_stride,
)
]
)
else:
self.downsamplers = None
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None, scale=scale)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale)
return hidden_states
class UpDecoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
upsample_scale_factor=(2, 2, 2),
temb_channels: Optional[int] = None,
):
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockCausal3D(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList(
[
UpsampleCausal3D(
out_channels,
use_conv=True,
out_channels=out_channels,
upsample_factor=upsample_scale_factor,
)
]
)
else:
self.upsamplers = None
self.resolution_idx = resolution_idx
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from .unet_causal_3d_blocks import (
CausalConv3d,
UNetMidBlockCausal3D,
get_down_block3d,
get_up_block3d,
)
@dataclass
class DecoderOutput(BaseOutput):
r"""
Output of decoding method.
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The decoded output sample from the last layer of the model.
"""
sample: torch.FloatTensor
class EncoderCausal3D(nn.Module):
r"""
The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
mid_block_add_attention=True,
time_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
num_time_downsample_layers = int(np.log2(time_compression_ratio))
if time_compression_ratio == 4:
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
add_time_downsample = bool(
i >= (len(block_out_channels) - 1 - num_time_downsample_layers)
and not is_final_block
)
else:
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
downsample_stride_T = (2,) if add_time_downsample else (1,)
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
down_block = get_down_block3d(
down_block_type,
num_layers=self.layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=bool(add_spatial_downsample or add_time_downsample),
downsample_stride=downsample_stride,
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=None,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlockCausal3D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
add_attention=mid_block_add_attention,
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `EncoderCausal3D` class."""
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
sample = self.conv_in(sample)
# down
for down_block in self.down_blocks:
sample = down_block(sample)
# middle
sample = self.mid_block(sample)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class DecoderCausal3D(nn.Module):
r"""
The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
mid_block_add_attention=True,
time_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlockCausal3D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
add_attention=mid_block_add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
num_time_upsample_layers = int(np.log2(time_compression_ratio))
if time_compression_ratio == 4:
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
add_time_upsample = bool(
i >= len(block_out_channels) - 1 - num_time_upsample_layers
and not is_final_block
)
else:
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
up_block = get_up_block3d(
up_block_type,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
add_upsample=bool(add_spatial_upsample or add_time_upsample),
upsample_scale_factor=upsample_scale_factor,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.FloatTensor,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `DecoderCausal3D` class."""
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds)
# post-process
if latent_embeds is None:
sample = self.conv_norm_out(sample)
else:
sample = self.conv_norm_out(sample, latent_embeds)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
if parameters.ndim == 3:
dim = 2 # (B, L, C)
elif parameters.ndim == 5 or parameters.ndim == 4:
dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
else:
raise NotImplementedError
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
reduce_dim = list(range(1, self.mean.ndim))
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=reduce_dim,
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=reduce_dim,
)
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar +
torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from .unet_causal_3d_blocks import (
CausalConv3d,
UNetMidBlockCausal3D,
get_down_block3d,
get_up_block3d,
)
@dataclass
class DecoderOutput(BaseOutput):
r"""
Output of decoding method.
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The decoded output sample from the last layer of the model.
"""
sample: torch.FloatTensor
class EncoderCausal3D(nn.Module):
r"""
The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
mid_block_add_attention=True,
time_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
num_time_downsample_layers = int(np.log2(time_compression_ratio))
if time_compression_ratio == 4:
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
else:
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
downsample_stride_T = (2,) if add_time_downsample else (1,)
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
down_block = get_down_block3d(
down_block_type,
num_layers=self.layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=bool(add_spatial_downsample or add_time_downsample),
downsample_stride=downsample_stride,
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=None,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlockCausal3D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
add_attention=mid_block_add_attention,
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `EncoderCausal3D` class."""
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
sample = self.conv_in(sample)
# down
for down_block in self.down_blocks:
sample = down_block(sample)
# middle
sample = self.mid_block(sample)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class DecoderCausal3D(nn.Module):
r"""
The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
mid_block_add_attention=True,
time_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlockCausal3D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
add_attention=mid_block_add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
num_time_upsample_layers = int(np.log2(time_compression_ratio))
if time_compression_ratio == 4:
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
else:
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
up_block = get_up_block3d(
up_block_type,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
add_upsample=bool(add_spatial_upsample or add_time_upsample),
upsample_scale_factor=upsample_scale_factor,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.FloatTensor,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `DecoderCausal3D` class."""
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds)
# post-process
if latent_embeds is None:
sample = self.conv_norm_out(sample)
else:
sample = self.conv_norm_out(sample, latent_embeds)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
if parameters.ndim == 3:
dim = 2 # (B, L, C)
elif parameters.ndim == 5 or parameters.ndim == 4:
dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
else:
raise NotImplementedError
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
reduce_dim = list(range(1, self.mean.ndim))
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=reduce_dim,
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=reduce_dim,
)
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean
......@@ -44,7 +44,6 @@ class CausalConv3d(nn.Conv3d):
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
......@@ -56,16 +55,10 @@ class RMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return (
F.normalize(x, dim=(1 if self.channel_first else -1))
* self.scale
* self.gamma
+ self.bias
)
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
......@@ -74,7 +67,6 @@ class Upsample(nn.Upsample):
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in (
"none",
......@@ -101,16 +93,10 @@ class Resample(nn.Module):
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
......@@ -124,28 +110,17 @@ class Resample(nn.Module):
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] != "Rep"
):
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] == "Rep"
):
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat(
[torch.zeros_like(cache_x).to(cache_x.device), cache_x],
dim=2,
......@@ -172,15 +147,12 @@ class Resample(nn.Module):
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
)
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
......@@ -210,7 +182,6 @@ class Resample(nn.Module):
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
......@@ -226,9 +197,7 @@ class ResidualBlock(nn.Module):
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
......@@ -240,9 +209,7 @@ class ResidualBlock(nn.Module):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
......@@ -278,13 +245,7 @@ class AttentionBlock(nn.Module):
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm(x)
# compute query, key, value
q, k, v = (
self.to_qkv(x)
.reshape(b * t, 1, c * 3, -1)
.permute(0, 1, 3, 2)
.contiguous()
.chunk(3, dim=-1)
)
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
......@@ -301,7 +262,6 @@ class AttentionBlock(nn.Module):
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
......@@ -400,9 +360,7 @@ class Encoder3d(nn.Module):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
......@@ -416,7 +374,6 @@ class Encoder3d(nn.Module):
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
......@@ -518,9 +475,7 @@ class Decoder3d(nn.Module):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
......@@ -542,7 +497,6 @@ def count_conv3d(model):
class WanVAE_(nn.Module):
def __init__(
self,
dim=128,
......@@ -613,9 +567,7 @@ class WanVAE_(nn.Module):
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1
)
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
......@@ -625,9 +577,7 @@ class WanVAE_(nn.Module):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1
)
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
......@@ -700,7 +650,6 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
class WanVAE:
def __init__(
self,
z_dim=16,
......@@ -780,11 +729,8 @@ class WanVAE:
"""
videos: A list of videos each with shape [C, T, H, W].
"""
return [
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
]
return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
def decode_dist(self, zs, world_size, cur_rank, split_dim):
splited_total_len = zs.shape[split_dim]
splited_chunk_len = splited_total_len // world_size
......@@ -792,37 +738,37 @@ class WanVAE:
if cur_rank == 0:
if split_dim == 2:
zs = zs[:,:,:splited_chunk_len+2*padding_size,:].contiguous()
zs = zs[:, :, : splited_chunk_len + 2 * padding_size, :].contiguous()
elif split_dim == 3:
zs = zs[:,:,:,:splited_chunk_len+2*padding_size].contiguous()
elif cur_rank == world_size-1:
zs = zs[:, :, :, : splited_chunk_len + 2 * padding_size].contiguous()
elif cur_rank == world_size - 1:
if split_dim == 2:
zs = zs[:,:,-(splited_chunk_len+2*padding_size):,:].contiguous()
zs = zs[:, :, -(splited_chunk_len + 2 * padding_size) :, :].contiguous()
elif split_dim == 3:
zs = zs[:,:,:,-(splited_chunk_len+2*padding_size):].contiguous()
zs = zs[:, :, :, -(splited_chunk_len + 2 * padding_size) :].contiguous()
else:
if split_dim == 2:
zs = zs[:,:,cur_rank*splited_chunk_len-padding_size:(cur_rank+1)*splited_chunk_len+padding_size,:].contiguous()
zs = zs[:, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size, :].contiguous()
elif split_dim == 3:
zs = zs[:,:,:,cur_rank*splited_chunk_len-padding_size:(cur_rank+1)*splited_chunk_len+padding_size].contiguous()
zs = zs[:, :, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size].contiguous()
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
if cur_rank == 0:
if split_dim == 2:
images = images[:,:,:,:splited_chunk_len*8,:].contiguous()
images = images[:, :, :, : splited_chunk_len * 8, :].contiguous()
elif split_dim == 3:
images = images[:,:,:,:,:splited_chunk_len*8].contiguous()
elif cur_rank == world_size-1:
images = images[:, :, :, :, : splited_chunk_len * 8].contiguous()
elif cur_rank == world_size - 1:
if split_dim == 2:
images = images[:,:,:,-splited_chunk_len*8:,:].contiguous()
images = images[:, :, :, -splited_chunk_len * 8 :, :].contiguous()
elif split_dim == 3:
images = images[:,:,:,:,-splited_chunk_len*8:].contiguous()
images = images[:, :, :, :, -splited_chunk_len * 8 :].contiguous()
else:
if split_dim == 2:
images = images[:,:,:,8*padding_size:-8*padding_size,:].contiguous()
images = images[:, :, :, 8 * padding_size : -8 * padding_size, :].contiguous()
elif split_dim == 3:
images = images[:,:,:,:,8*padding_size:-8*padding_size].contiguous()
images = images[:, :, :, :, 8 * padding_size : -8 * padding_size].contiguous()
full_images = [torch.empty_like(images) for _ in range(world_size)]
dist.all_gather(full_images, images)
......@@ -832,7 +778,6 @@ class WanVAE:
images = torch.cat(full_images, dim=-1)
return images
def decode(self, zs, generator, args):
if args.cpu_offload:
......
......@@ -5,7 +5,7 @@ from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.autoencod
from lightx2v.text2v.models.video_encoders.trt.autoencoder_kl_causal_3d import trt_vae_infer
class VideoEncoderKLCausal3DModel():
class VideoEncoderKLCausal3DModel:
def __init__(self, model_path, dtype, device):
self.model_path = model_path
self.dtype = dtype
......@@ -13,10 +13,10 @@ class VideoEncoderKLCausal3DModel():
self.load()
def load(self):
self.vae_path = os.path.join(self.model_path, 'hunyuan-video-t2v-720p/vae')
self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
config = AutoencoderKLCausal3D.load_config(self.vae_path)
self.model = AutoencoderKLCausal3D.from_config(config)
ckpt = torch.load(os.path.join(self.vae_path, 'pytorch_model.pt'), map_location='cpu', weights_only=True)
ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
self.model.load_state_dict(ckpt)
self.model = self.model.to(dtype=self.dtype, device=self.device)
self.model.requires_grad_(False)
......@@ -28,12 +28,11 @@ class VideoEncoderKLCausal3DModel():
latents = latents / self.model.config.scaling_factor
latents = latents.to(dtype=self.dtype, device=self.device)
self.model.enable_tiling()
image = self.model.decode(
latents, return_dict=False, generator=generator
)[0]
image = self.model.decode(latents, return_dict=False, generator=generator)[0]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().float()
return image
if __name__ == "__main__":
vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme1/yongyang/models/hy/ckpts", dtype=torch.float16, device=torch.device("cuda"))
......@@ -100,18 +100,18 @@ class HyVaeTrtModelInfer(nn.Module):
device = batch.device
dtype = batch.dtype
batch = batch.cpu().numpy()
def get_output_shape(shp):
b, c, t, h, w = shp
out = (b, 3, 4*(t-1)+1, h*8, w*8)
out = (b, 3, 4 * (t - 1) + 1, h * 8, w * 8)
return out
shp_dict = {"inp": batch.shape, "out": get_output_shape(batch.shape)}
self.alloc(shp_dict)
output = np.zeros(*self.output_spec())
# Process I/O and execute the network
common.memcpy_host_to_device(
self.inputs[0]["allocation"], np.ascontiguousarray(batch)
)
common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(batch))
self.context.execute_v2(self.allocations)
common.memcpy_device_to_host(output, self.outputs[0]["allocation"])
output = torch.from_numpy(output).to(device).type(dtype)
......@@ -122,19 +122,16 @@ class HyVaeTrtModelInfer(nn.Module):
logger.info("Start to do VAE onnx exporting.")
device = next(decoder.parameters())[0].device
example_inp = torch.rand(1, 16, 17, 32, 32).to(device).type(next(decoder.parameters())[0].dtype)
out_path = str(Path(str(model_dir))/"vae_decoder.onnx")
out_path = str(Path(str(model_dir)) / "vae_decoder.onnx")
torch.onnx.export(
decoder.eval().half(),
example_inp.half(),
out_path,
input_names=['inp'],
output_names=['out'],
input_names=["inp"],
output_names=["out"],
opset_version=14,
dynamic_axes={
"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"},
"out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}
}
)
dynamic_axes={"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}, "out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}},
)
# onnx_ori = onnx.load(out_path)
os.system(f"onnxsim {out_path} {out_path}")
# onnx_opt, check = simplify(onnx_ori)
......@@ -163,4 +160,4 @@ class HyVaeTrtModelInfer(nn.Module):
if not Path(engine_path).exists():
raise RuntimeError(f"Convert vae onnx({onnx_path}) to tensorrt engine failed.")
logger.info("Finish VAE tensorrt converting.")
return engine_path
\ No newline at end of file
return engine_path
......@@ -8,20 +8,20 @@ class BaseQuantizer(object):
self.sym = symmetric
self.granularity = granularity
self.kwargs = kwargs
if self.granularity == 'per_group':
self.group_size = self.kwargs['group_size']
self.calib_algo = self.kwargs.get('calib_algo', 'minmax')
if self.granularity == "per_group":
self.group_size = self.kwargs["group_size"]
self.calib_algo = self.kwargs.get("calib_algo", "minmax")
def get_tensor_range(self, tensor):
if self.calib_algo == 'minmax':
if self.calib_algo == "minmax":
return self.get_minmax_range(tensor)
elif self.calib_algo == 'mse':
elif self.calib_algo == "mse":
return self.get_mse_range(tensor)
else:
raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}')
raise ValueError(f"Unsupported calibration algorithm: {self.calib_algo}")
def get_minmax_range(self, tensor):
if self.granularity == 'per_tensor':
if self.granularity == "per_tensor":
max_val = torch.max(tensor)
min_val = torch.min(tensor)
else:
......@@ -47,7 +47,7 @@ class BaseQuantizer(object):
return scales, zeros, qmax, qmin
def reshape_tensor(self, tensor, allow_padding=False):
if self.granularity == 'per_group':
if self.granularity == "per_group":
t = tensor.reshape(-1, self.group_size)
else:
t = tensor
......@@ -79,7 +79,7 @@ class BaseQuantizer(object):
tensor, scales, zeros, qmax, qmin = self.get_tensor_qparams(tensor)
tensor = self.quant(tensor, scales, zeros, qmax, qmin)
tensor = self.restore_tensor(tensor, org_shape)
if self.sym == True:
if self.sym:
zeros = None
return tensor, scales, zeros
......@@ -87,9 +87,9 @@ class BaseQuantizer(object):
class IntegerQuantizer(BaseQuantizer):
def __init__(self, bit, symmetric, granularity, **kwargs):
super().__init__(bit, symmetric, granularity, **kwargs)
if 'int_range' in self.kwargs:
self.qmin = self.kwargs['int_range'][0]
self.qmax = self.kwargs['int_range'][1]
if "int_range" in self.kwargs:
self.qmin = self.kwargs["int_range"][0]
self.qmax = self.kwargs["int_range"][1]
else:
if self.sym:
self.qmin = -(2 ** (self.bit - 1))
......@@ -110,7 +110,14 @@ class IntegerQuantizer(BaseQuantizer):
tensor = (tensor - zeros) * scales
return tensor
def quant_dequant(self, tensor, scales, zeros, qmax, qmin,):
def quant_dequant(
self,
tensor,
scales,
zeros,
qmax,
qmin,
):
tensor = self.quant(tensor, scales, zeros, qmax, qmin)
tensor = self.dequant(tensor, scales, zeros)
return tensor
......@@ -119,19 +126,19 @@ class IntegerQuantizer(BaseQuantizer):
class FloatQuantizer(BaseQuantizer):
def __init__(self, bit, symmetric, granularity, **kwargs):
super().__init__(bit, symmetric, granularity, **kwargs)
assert self.bit in ['e4m3', 'e5m2'], f'Unsupported bit configuration: {self.bit}'
assert self.sym == True
if self.bit == 'e4m3':
assert self.bit in ["e4m3", "e5m2"], f"Unsupported bit configuration: {self.bit}"
assert self.sym
if self.bit == "e4m3":
self.e_bits = 4
self.m_bits = 3
self.fp_dtype = torch.float8_e4m3fn
elif self.bit == 'e5m2':
elif self.bit == "e5m2":
self.e_bits = 5
self.m_bits = 2
self.fp_dtype = torch.float8_e5m2
else:
raise ValueError(f'Unsupported bit configuration: {self.bit}')
raise ValueError(f"Unsupported bit configuration: {self.bit}")
finfo = torch.finfo(self.fp_dtype)
self.qmin, self.qmax = finfo.min, finfo.max
......@@ -141,13 +148,9 @@ class FloatQuantizer(BaseQuantizer):
def quant(self, tensor, scales, zeros, qmax, qmin):
scaled_tensor = tensor / scales + zeros
scaled_tensor = torch.clip(
scaled_tensor, self.qmin.cuda(), self.qmax.cuda()
)
scaled_tensor = torch.clip(scaled_tensor, self.qmin.cuda(), self.qmax.cuda())
org_dtype = scaled_tensor.dtype
q_tensor = float_quantize(
scaled_tensor.float(), self.e_bits, self.m_bits, rounding='nearest'
)
q_tensor = float_quantize(scaled_tensor.float(), self.e_bits, self.m_bits, rounding="nearest")
q_tensor.to(org_dtype)
return q_tensor
......@@ -161,21 +164,21 @@ class FloatQuantizer(BaseQuantizer):
return tensor
if __name__ == '__main__':
if __name__ == "__main__":
weight = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda()
quantizer = IntegerQuantizer(4, False, 'per_group', group_size=128)
quantizer = IntegerQuantizer(4, False, "per_group", group_size=128)
q_weight = quantizer.fake_quant_tensor(weight)
print(weight)
print(q_weight)
print(f"cosine = {torch.cosine_similarity(weight.view(1, -1).to(torch.float64), q_weight.view(1, -1).to(torch.float64))}")
realq_weight, scales, zeros = quantizer.real_quant_tensor(weight)
print(f"realq_weight = {realq_weight}, {realq_weight.shape}")
print(f"scales = {scales}, {scales.shape}")
print(f"zeros = {zeros}, {zeros.shape}")
weight = torch.randn(8192, 4096, dtype=torch.bfloat16).cuda()
quantizer = FloatQuantizer('e4m3', True, 'per_channel')
quantizer = FloatQuantizer("e4m3", True, "per_channel")
q_weight = quantizer.fake_quant_tensor(weight)
print(weight)
print(q_weight)
......
......@@ -11,14 +11,14 @@ class Register(dict):
def register(self, target, key=None):
if not callable(target):
raise Exception(f'Error: {target} must be callable!')
raise Exception(f"Error: {target} must be callable!")
if key is None:
key = target.__name__
if key in self._dict:
raise Exception(f'{key} already exists.')
raise Exception(f"{key} already exists.")
self[key] = target
return target
......
......@@ -66,12 +66,7 @@ def cache_video(
# preprocess
tensor = tensor.clamp(min(value_range), max(value_range))
tensor = torch.stack(
[
torchvision.utils.make_grid(
u, nrow=nrow, normalize=normalize, value_range=value_range
)
for u in tensor.unbind(2)
],
[torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range) for u in tensor.unbind(2)],
dim=1,
).permute(1, 2, 3, 0)
tensor = (tensor * 255).type(torch.uint8).cpu()
......
[tool.ruff]
exclude = [".git", ".mypy_cache", ".ruff_cache", ".venv", "dist"]
target-version = "py311"
line-length = 200
indent-width = 4
lint.ignore =["F"]
[tool.ruff.format]
line-ending = "lf"
quote-style = "double"
indent-style = "space"
......@@ -18,4 +18,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
--target_width 1280 \
--attention_type flash_attn3 \
--save_video_path ./output_lightx2v_int8.mp4 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
......@@ -18,4 +18,4 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--target_width 1280 \
--attention_type flash_attn2 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
--parallel_attn
\ No newline at end of file
--parallel_attn
......@@ -20,4 +20,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
--cpu_offload \
--feature_caching TaylorSeer \
--save_video_path ./output_lightx2v_offload_TaylorSeer.mp4 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
\ No newline at end of file
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
......@@ -28,4 +28,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--mm_config '{"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
# --feature_caching Tea \
# --use_ret_steps \
\ No newline at end of file
# --use_ret_steps \
......@@ -28,4 +28,4 @@ python ${lightx2v_path}/lightx2v/__main__.py \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F", "weight_auto_quant": true}' \
# --feature_caching Tea \
# --use_ret_steps \
# --teacache_thresh 0.2
\ No newline at end of file
# --teacache_thresh 0.2
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment