Commit b96ae489 authored by mashun1's avatar mashun1
Browse files

magic-animate

parents
Pipeline #674 canceled with stages
This diff is collapsed.
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import FeedForward, AdaLayerNorm
from diffusers.models.attention import Attention as CrossAttention
from einops import rearrange, repeat
@dataclass
class Transformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class Transformer3DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
# Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
)
for d in range(num_layers)
]
)
# 4. Define output layers
if use_linear_projection:
self.proj_out = nn.Linear(in_channels, inner_dim)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
# Input
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
# JH: need not repeat when a list of prompts are given
if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)
# Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
video_length=video_length
)
# Output
if not self.use_linear_projection:
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention = None,
unet_use_temporal_attention = None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
# SC-Attn
assert unet_use_cross_frame_attention is not None
if unet_use_cross_frame_attention:
self.attn1 = SparseCausalAttention2D(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
else:
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
# Cross-Attn
if cross_attention_dim is not None:
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
else:
self.attn2 = None
if cross_attention_dim is not None:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
else:
self.norm2 = None
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
self.use_ada_layer_norm_zero = False
# Temp-Attn
assert unet_use_temporal_attention is not None
if unet_use_temporal_attention:
self.attn_temp = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
if self.attn2 is not None:
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
# SparseCausal-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
# if self.only_cross_attention:
# hidden_states = (
# self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
# )
# else:
# hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
# pdb.set_trace()
if self.unet_use_cross_frame_attention:
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
else:
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
+ hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
# Temporal-Attention
if self.unet_use_temporal_attention:
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
norm_hidden_states = (
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
)
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
This diff is collapsed.
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import numpy as np
import torch
from torch import nn
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
height=224,
width=224,
patch_size=16,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
):
super().__init__()
num_patches = (height // patch_size) * (width // patch_size)
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
if layer_norm:
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
else:
self.norm = None
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
return latent + self.pos_embed
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
if act_fn == "silu":
self.act = nn.SiLU()
elif act_fn == "mish":
self.act = nn.Mish()
elif act_fn == "gelu":
self.act = nn.GELU()
else:
raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
if post_act_fn is None:
self.post_act = None
elif post_act_fn == "silu":
self.post_act = nn.SiLU()
elif post_act_fn == "mish":
self.post_act = nn.Mish()
elif post_act_fn == "gelu":
self.post_act = nn.GELU()
else:
raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
def __init__(
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
):
super().__init__()
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.log = log
self.flip_sin_to_cos = flip_sin_to_cos
if set_W_to_weight:
# to delete later
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.weight = self.W
def forward(self, x):
if self.log:
x = torch.log(x)
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
if self.flip_sin_to_cos:
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
else:
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out
class ImagePositionalEmbeddings(nn.Module):
"""
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
height and width of the latent space.
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
For VQ-diffusion:
Output vector embeddings are used as input for the transformer.
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
Args:
num_embed (`int`):
Number of embeddings for the latent pixels embeddings.
height (`int`):
Height of the latent image i.e. the number of height embeddings.
width (`int`):
Width of the latent image i.e. the number of width embeddings.
embed_dim (`int`):
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
"""
def __init__(
self,
num_embed: int,
height: int,
width: int,
embed_dim: int,
):
super().__init__()
self.height = height
self.width = width
self.num_embed = num_embed
self.embed_dim = embed_dim
self.emb = nn.Embedding(self.num_embed, embed_dim)
self.height_emb = nn.Embedding(self.height, embed_dim)
self.width_emb = nn.Embedding(self.width, embed_dim)
def forward(self, index):
emb = self.emb(index)
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
# 1 x H x D -> 1 x H x 1 x D
height_emb = height_emb.unsqueeze(2)
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
# 1 x W x D -> 1 x 1 x W x D
width_emb = width_emb.unsqueeze(1)
pos_emb = height_emb + width_emb
# 1 x H x W x D -> 1 x L xD
pos_emb = pos_emb.view(1, self.height * self.width, -1)
emb = emb + pos_emb[:, : emb.shape[1], :]
return emb
class LabelEmbedding(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
Args:
num_classes (`int`): The number of classes.
hidden_size (`int`): The size of the vector embeddings.
dropout_prob (`float`): The probability of dropping a label.
"""
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
else:
drop_ids = torch.tensor(force_drop_ids == 1)
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (self.training and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
def forward(self, timestep, class_labels, hidden_dtype=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
class_labels = self.class_embedder(class_labels) # (N, D)
conditioning = timesteps_emb + class_labels # (N, D)
return conditioning
\ No newline at end of file
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Adapted from https://github.com/guoyww/AnimateDiff
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import FeedForward
from magicanimate.models.orig_attention import CrossAttention
from einops import rearrange, repeat
import math
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
def get_motion_module(
in_channels,
motion_module_type: str,
motion_module_kwargs: dict
):
if motion_module_type == "Vanilla":
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
else:
raise ValueError
class VanillaTemporalModule(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads = 8,
num_transformer_block = 2,
attention_block_types =( "Temporal_Self", "Temporal_Self" ),
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
temporal_attention_dim_div = 1,
zero_initialize = True,
):
super().__init__()
self.temporal_transformer = TemporalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
num_layers=num_transformer_block,
attention_block_types=attention_block_types,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
if zero_initialize:
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
hidden_states = input_tensor
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
output = hidden_states
return output
class TemporalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
num_layers,
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
dropout = 0.0,
norm_num_groups = 32,
cross_attention_dim = 768,
activation_fn = "geglu",
attention_bias = False,
upcast_attention = False,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
attention_block_types=attention_block_types,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)
# Transformer Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
# output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
return output
class TemporalTransformerBlock(nn.Module):
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
dropout = 0.0,
norm_num_groups = 32,
cross_attention_dim = 768,
activation_fn = "geglu",
attention_bias = False,
upcast_attention = False,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
):
super().__init__()
attention_blocks = []
norms = []
for block_name in attention_block_types:
attention_blocks.append(
VersatileAttention(
attention_mode=block_name.split("_")[0],
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
)
norms.append(nn.LayerNorm(dim))
self.attention_blocks = nn.ModuleList(attention_blocks)
self.norms = nn.ModuleList(norms)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff_norm = nn.LayerNorm(dim)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
for attention_block, norm in zip(self.attention_blocks, self.norms):
norm_hidden_states = norm(hidden_states)
hidden_states = attention_block(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
video_length=video_length,
) + hidden_states
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
output = hidden_states
return output
class PositionalEncoding(nn.Module):
def __init__(
self,
d_model,
dropout = 0.,
max_len = 24
):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
class VersatileAttention(CrossAttention):
def __init__(
self,
attention_mode = None,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
*args, **kwargs
):
super().__init__(*args, **kwargs)
assert attention_mode == "Temporal"
self.attention_mode = attention_mode
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
self.pos_encoder = PositionalEncoding(
kwargs["query_dim"],
dropout=0.,
max_len=temporal_position_encoding_max_len
) if (temporal_position_encoding and attention_mode == "Temporal") else None
def extra_repr(self):
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
batch_size, sequence_length, _ = hidden_states.shape
if self.attention_mode == "Temporal":
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states)
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
else:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
if self.attention_mode == "Temporal":
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
This diff is collapsed.
This diff is collapsed.
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Adapted from https://github.com/guoyww/AnimateDiff
# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class InflatedConv3d(nn.Conv2d):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class Upsample3D(nn.Module):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose:
raise NotImplementedError
elif use_conv:
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
def forward(self, hidden_states, output_size=None):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
raise NotImplementedError
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
hidden_states = self.conv(hidden_states)
return hidden_states
class Downsample3D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
raise NotImplementedError
def forward(self, hidden_states):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
raise NotImplementedError
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
time_emb_proj_out_channels = out_channels
elif self.time_embedding_norm == "scale_shift":
time_emb_proj_out_channels = out_channels * 2
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, input_tensor, temb):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class Mish(torch.nn.Module):
def forward(self, hidden_states):
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Adapted from https://github.com/s9roll7/animatediff-cli-prompt-travel/tree/main
import numpy as np
from typing import Callable, Optional, List
def ordered_halving(val):
bin_str = f"{val:064b}"
bin_flip = bin_str[::-1]
as_int = int(bin_flip, 2)
return as_int / (1 << 64)
def uniform(
step: int = ...,
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
if num_frames <= context_size:
yield list(range(num_frames))
return
context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)
for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(step)))
for j in range(
int(ordered_halving(step) * context_step) + pad,
num_frames + pad + (0 if closed_loop else -context_overlap),
(context_size * context_step - context_overlap),
):
yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]
def get_context_scheduler(name: str) -> Callable:
if name == "uniform":
return uniform
else:
raise ValueError(f"Unknown context_overlap policy {name}")
def get_total_steps(
scheduler,
timesteps: List[int],
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
return sum(
len(
list(
scheduler(
i,
num_steps,
num_frames,
context_size,
context_stride,
context_overlap,
)
)
)
for i in range(len(timesteps))
)
This diff is collapsed.
# Copyright 2023 ByteDance and/or its affiliates.
#
# Copyright (2023) MagicAnimate Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import os
import socket
import warnings
import torch
from torch import distributed as dist
def distributed_init(args):
if dist.is_initialized():
warnings.warn("Distributed is already initialized, cannot initialize twice!")
args.rank = dist.get_rank()
else:
print(
f"Distributed Init (Rank {args.rank}): "
f"{args.init_method}"
)
dist.init_process_group(
backend='nccl',
init_method=args.init_method,
world_size=args.world_size,
rank=args.rank,
)
print(
f"Initialized Host {socket.gethostname()} as Rank "
f"{args.rank}"
)
if "MASTER_ADDR" not in os.environ or "MASTER_PORT" not in os.environ:
# Set for onboxdataloader support
split = args.init_method.split("//")
assert len(split) == 2, (
"host url for distributed should be split by '//' "
+ "into exactly two elements"
)
split = split[1].split(":")
assert (
len(split) == 2
), "host url should be of the form <host_url>:<host_port>"
os.environ["MASTER_ADDR"] = split[0]
os.environ["MASTER_PORT"] = split[1]
# perform a dummy all-reduce to initialize the NCCL communicator
dist.all_reduce(torch.zeros(1).cuda())
suppress_output(is_master())
args.rank = dist.get_rank()
return args.rank
def get_rank():
if not dist.is_available():
return 0
if not dist.is_nccl_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def is_master():
return get_rank() == 0
def synchronize():
if dist.is_initialized():
dist.barrier()
def suppress_output(is_master):
"""Suppress printing on the current device. Force printing with `force=True`."""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
import warnings
builtin_warn = warnings.warn
def warn(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_warn(*args, **kwargs)
# Log warnings only once
warnings.warn = warn
warnings.simplefilter("once", UserWarning)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment