"server/marlin/COPYRIGHT" did not exist on "f433f1f7705ba5d9110532a223d340effef059de"
Commit 1f5da520 authored by yangzhong's avatar yangzhong
Browse files

git init

parents
Pipeline #3144 failed with stages
in 0 seconds
# # This source code is licensed under the license found in the
# # LICENSE file in the root directory of this source tree.
# # --------------------------------------------------------
# # References:
# # PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# # Latte: https://github.com/Vchitect/Latte
# # DiT: https://github.com/facebookresearch/DiT/tree/main
# # GLIDE: https://github.com/openai/glide-text2im
# # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# # --------------------------------------------------------
# import math
# from typing import KeysView
# import numpy as np
# import torch
# import torch.distributed as dist
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.utils.checkpoint
# import xformers.ops
# from einops import rearrange
# from timm.models.vision_transformer import Mlp
# from opensora.acceleration.communications import all_to_all, split_forward_gather_backward
# from opensora.acceleration.parallel_states import get_sequence_parallel_group
# import ipdb
# import cv2
# import os
# approx_gelu = lambda: nn.GELU(approximate="tanh")
# class LlamaRMSNorm(nn.Module):
# def __init__(self, hidden_size, eps=1e-6):
# """
# LlamaRMSNorm is equivalent to T5LayerNorm
# """
# super().__init__()
# self.weight = nn.Parameter(torch.ones(hidden_size))
# self.variance_epsilon = eps
# def forward(self, hidden_states):
# #ipdb.set_trace()
# input_dtype = hidden_states.dtype
# hidden_states = hidden_states.to(torch.float32)
# variance = hidden_states.pow(2).mean(-1, keepdim=True)
# hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# #ipdb.set_trace()
# return self.weight * hidden_states.to(input_dtype)
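# # Added note (not in the original file): RMSNorm computes
# #     y = weight * x / sqrt(mean(x**2, dim=-1) + eps)
# # i.e. LayerNorm without mean subtraction and without a bias term; the statistics are taken in
# # float32 above and the result is cast back to the input dtype.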
# def get_layernorm(hidden_size: int, eps: float, affine: bool, use_kernel: bool):
# if use_kernel:
# try:
# from apex.normalization import FusedLayerNorm
# return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
# except ImportError:
# raise RuntimeError("FusedLayerNorm not available. Please install apex.")
# else:
# return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)
# def modulate(norm_func, x, shift, scale):
# # Suppose x is (B, N, D), shift is (B, D), scale is (B, D)
# dtype = x.dtype
# x = norm_func(x.to(torch.float32)).to(dtype)
# x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
# x = x.to(dtype)
# return x
# def t2i_modulate(x, shift, scale):
# return x * (1 + scale) + shift
# def get_attn_mask(query, key, idx):
# scale = 1.0 / query.shape[-1] ** 0.5
# query = query * scale
# query = query.transpose(0, 1)
# key = key.transpose(0, 1)
# attn = query @ key.transpose(-2, -1)
# attn = attn.softmax(-1) # H S L
# #attn = F.dropout(attn, p=0.0)
# # attn[attn>0.5]=1
# # attn[attn<=0.5]=0
# # attn[attn==1]=255
# attn = attn * 255
# H, S, L = attn.shape
# height = 64
# width = 64
# for h in range(H):
# for l in range(L):
# attn_map = attn[h, :, l]
# attn_map = rearrange(attn_map, '(H W) -> H W', H=height, W=width)
# attn_map = F.interpolate(attn_map.unsqueeze(0).unsqueeze(0), size=[256, 256], mode='nearest')
# np_array = attn_map.squeeze(0).squeeze(0).detach().cpu().numpy()
# image = cv2.imwrite(os.path.join("/mnt/bn/yh-volume0/code/debug/code/OpenSora/outputs/vis", 'map' + str(idx) + '_head' + str(h) + '_word' + str(l) + '.jpg'), np_array)
# print("Image saved successfully!")
# # ===============================================
# # General-purpose Layers
# # ===============================================
# class PatchEmbed3D(nn.Module):
# """Video to Patch Embedding.
# Args:
# patch_size (int): Patch token size. Default: (2,4,4).
# in_chans (int): Number of input video channels. Default: 3.
# embed_dim (int): Number of linear projection output channels. Default: 96.
# norm_layer (nn.Module, optional): Normalization layer. Default: None
# """
# def __init__(
# self,
# patch_size=(2, 4, 4),
# in_chans=3,
# embed_dim=96,
# norm_layer=None,
# flatten=True,
# ):
# super().__init__()
# self.patch_size = patch_size
# self.flatten = flatten
# self.in_chans = in_chans
# self.embed_dim = embed_dim
# self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
# if norm_layer is not None:
# self.norm = norm_layer(embed_dim)
# else:
# self.norm = None
# def forward(self, x):
# """Forward function."""
# # padding
# _, _, D, H, W = x.size()
# if W % self.patch_size[2] != 0:
# x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
# if H % self.patch_size[1] != 0:
# x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
# if D % self.patch_size[0] != 0:
# x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
# x = self.proj(x) # (B C T H W)
# if self.norm is not None:
# D, Wh, Ww = x.size(2), x.size(3), x.size(4)
# x = x.flatten(2).transpose(1, 2)
# x = self.norm(x)
# x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
# if self.flatten:
# x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
# return x
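# # Hedged shape example (added, not in the original file): with the defaults patch_size=(2, 4, 4)
# # and embed_dim=96, an input video of shape (B, 3, 16, 256, 256) is projected by the Conv3d to
# # (B, 96, 8, 64, 64) and, with flatten=True, returned as (B, 8*64*64, 96) = (B, 32768, 96).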
# class Attention(nn.Module):
# def __init__(
# self,
# dim: int,
# num_heads: int = 8,
# qkv_bias: bool = False,
# qk_norm: bool = False,
# attn_drop: float = 0.0,
# proj_drop: float = 0.0,
# norm_layer: nn.Module = nn.LayerNorm,
# enable_flashattn: bool = False,
# ) -> None:
# super().__init__()
# assert dim % num_heads == 0, "dim should be divisible by num_heads"
# self.dim = dim
# self.num_heads = num_heads
# self.head_dim = dim // num_heads
# self.scale = self.head_dim**-0.5
# self.enable_flashattn = enable_flashattn
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(dim, dim)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x: torch.Tensor) -> torch.Tensor:
# B, N, C = x.shape
# qkv = self.qkv(x)
# qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
# if self.enable_flashattn: # here
# qkv_permute_shape = (2, 0, 1, 3, 4)
# else:
# qkv_permute_shape = (2, 0, 3, 1, 4)
# qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
# q, k, v = qkv.unbind(0)
# q, k = self.q_norm(q), self.k_norm(k)
# if self.enable_flashattn:
# from flash_attn import flash_attn_func
# x = flash_attn_func(
# q,
# k,
# v,
# dropout_p=self.attn_drop.p if self.training else 0.0,
# softmax_scale=self.scale,
# )
# else:
# dtype = q.dtype
# q = q * self.scale
# attn = q @ k.transpose(-2, -1) # attention scores; cast to float32 below for a stable softmax
# attn = attn.to(torch.float32)
# attn = attn.softmax(dim=-1)
# attn = attn.to(dtype) # cast back attn to original dtype
# attn = self.attn_drop(attn)
# x = attn @ v
# x_output_shape = (B, N, C)
# if not self.enable_flashattn:
# x = x.transpose(1, 2)
# x = x.reshape(x_output_shape)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
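# # Added note (not in the original file): flash_attn_func expects tensors laid out as
# # (B, N, num_heads, head_dim), while the manual softmax path works on (B, num_heads, N, head_dim);
# # that is why the permute order (and the final transpose) differ between the two branches.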
# class Attention_QKNorm_RoPE(nn.Module):
# def __init__(
# self,
# dim: int,
# num_heads: int = 8,
# qkv_bias: bool = False,
# qk_norm: bool = False,
# attn_drop: float = 0.0,
# proj_drop: float = 0.0,
# norm_layer: nn.Module = LlamaRMSNorm,
# enable_flashattn: bool = False,
# rope=None,
# ) -> None:
# super().__init__()
# assert dim % num_heads == 0, "dim should be divisible by num_heads"
# self.dim = dim
# self.num_heads = num_heads
# self.head_dim = dim // num_heads
# self.scale = self.head_dim**-0.5
# self.enable_flashattn = enable_flashattn
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(dim, dim)
# self.proj_drop = nn.Dropout(proj_drop)
# self.rotary_emb = rope
# def forward(self, x: torch.Tensor) -> torch.Tensor:
# B, N, C = x.shape
# qkv = self.qkv(x)
# qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
# if self.enable_flashattn:
# qkv_permute_shape = (2, 0, 1, 3, 4)
# else:
# qkv_permute_shape = (2, 0, 3, 1, 4)
# qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
# q, k, v = qkv.unbind(0)
# #ipdb.set_trace()
# if self.rotary_emb is not None:
# q = self.rotary_emb(q)
# k = self.rotary_emb(k)
# #ipdb.set_trace()
# q, k = self.q_norm(q), self.k_norm(k)
# #ipdb.set_trace()
# if self.enable_flashattn:
# from flash_attn import flash_attn_func
# x = flash_attn_func(
# q,
# k,
# v,
# dropout_p=self.attn_drop.p if self.training else 0.0,
# softmax_scale=self.scale,
# )
# else:
# dtype = q.dtype
# q = q * self.scale
# attn = q @ k.transpose(-2, -1) # attention scores; cast to float32 below for a stable softmax
# attn = attn.to(torch.float32)
# attn = attn.softmax(dim=-1)
# attn = attn.to(dtype) # cast back attn to original dtype
# attn = self.attn_drop(attn)
# x = attn @ v
# x_output_shape = (B, N, C)
# if not self.enable_flashattn:
# x = x.transpose(1, 2)
# x = x.reshape(x_output_shape)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class MaskedSelfAttention(nn.Module):
# def __init__(
# self,
# dim: int,
# num_heads: int = 8,
# qkv_bias: bool = False,
# qk_norm: bool = False,
# attn_drop: float = 0.0,
# proj_drop: float = 0.0,
# norm_layer: nn.Module = LlamaRMSNorm,
# enable_flashattn: bool = False,
# rope=None,
# ) -> None:
# super().__init__()
# assert dim % num_heads == 0, "dim should be divisible by num_heads"
# self.dim = dim
# self.num_heads = num_heads
# self.head_dim = dim // num_heads
# self.scale = self.head_dim**-0.5
# self.enable_flashattn = enable_flashattn
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(dim, dim)
# self.proj_drop = nn.Dropout(proj_drop)
# self.rotary_emb = rope
# def forward(self, x, mask):
# B, N, C = x.shape
# qkv = self.qkv(x)
# qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
# qkv_permute_shape = (2, 0, 3, 1, 4)
# qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
# q, k, v = qkv.unbind(0) # B H N C
# #ipdb.set_trace()
# if self.rotary_emb is not None:
# q = self.rotary_emb(q)
# k = self.rotary_emb(k)
# #ipdb.set_trace()
# q, k = self.q_norm(q), self.k_norm(k)
# #ipdb.set_trace()
# mask = mask.unsqueeze(1).unsqueeze(1).repeat(1, self.num_heads, 1, 1).to(torch.float32) # B H 1 N
# dtype = q.dtype
# q = q * self.scale
# attn = q @ k.transpose(-2, -1) # attention scores; cast to float32 below for a stable softmax
# attn = attn.to(torch.float32)
# attn = attn.masked_fill(mask == 0, -1e9)
# attn = attn.softmax(dim=-1)
# attn = attn.to(dtype) # cast back attn to original dtype
# attn = self.attn_drop(attn)
# x = attn @ v
# x_output_shape = (B, N, C)
# x = x.transpose(1, 2)
# x = x.reshape(x_output_shape)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class SeqParallelAttention(Attention):
# def __init__(
# self,
# dim: int,
# num_heads: int = 8,
# qkv_bias: bool = False,
# qk_norm: bool = False,
# attn_drop: float = 0.0,
# proj_drop: float = 0.0,
# norm_layer: nn.Module = nn.LayerNorm,
# enable_flashattn: bool = False,
# ) -> None:
# super().__init__(
# dim=dim,
# num_heads=num_heads,
# qkv_bias=qkv_bias,
# qk_norm=qk_norm,
# attn_drop=attn_drop,
# proj_drop=proj_drop,
# norm_layer=norm_layer,
# enable_flashattn=enable_flashattn,
# )
# def forward(self, x: torch.Tensor) -> torch.Tensor:
# B, N, C = x.shape # with sequence parallelism, N is the local (per-device) sequence length
# qkv = self.qkv(x)
# qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
# qkv = qkv.view(qkv_shape)
# sp_group = get_sequence_parallel_group()
# # apply all_to_all to gather sequence and split attention heads
# # [B, SUB_N, 3, NUM_HEAD, HEAD_DIM] -> [B, N, 3, NUM_HEAD_PER_DEVICE, HEAD_DIM]
# qkv = all_to_all(qkv, sp_group, scatter_dim=3, gather_dim=1)
# if self.enable_flashattn:
# qkv_permute_shape = (2, 0, 1, 3, 4) # [3, B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM]
# else:
# qkv_permute_shape = (2, 0, 3, 1, 4) # [3, B, NUM_HEAD_PER_DEVICE, N, HEAD_DIM]
# qkv = qkv.permute(qkv_permute_shape)
# q, k, v = qkv.unbind(0)
# q, k = self.q_norm(q), self.k_norm(k)
# if self.enable_flashattn:
# from flash_attn import flash_attn_func
# x = flash_attn_func(
# q,
# k,
# v,
# dropout_p=self.attn_drop.p if self.training else 0.0,
# softmax_scale=self.scale,
# )
# else:
# dtype = q.dtype
# q = q * self.scale
# attn = q @ k.transpose(-2, -1) # attention scores; cast to float32 below for a stable softmax
# attn = attn.to(torch.float32)
# attn = attn.softmax(dim=-1)
# attn = attn.to(dtype) # cast back attn to original dtype
# attn = self.attn_drop(attn)
# x = attn @ v
# if not self.enable_flashattn:
# x = x.transpose(1, 2)
# # apply all to all to gather back attention heads and split sequence
# # [B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM] -> [B, SUB_N, NUM_HEAD, HEAD_DIM]
# x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)
# # reshape outputs back to [B, N, C]
# x_output_shape = (B, N, C)
# x = x.reshape(x_output_shape)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class MultiHeadCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(MultiHeadCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None, i=0, t=0):
# # query: image tokens (x); key/value: condition tokens (cond); mask: list of valid condition-token counts per sample
# B, N, C = x.shape
# q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
# kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
# k, v = kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
# # vis
# print(t[0].item())
# if i == 27 and t[0].item() >= 480.0 and t[0].item() <= 500.0:
# q1 = q[:, :N, :, :].squeeze(0) # S H C
# q2 = q[:, N:, :, :].squeeze(0) # S H C
# k1 = k[:, :mask[0], :, :].squeeze(0) # L H C
# k2 = k[:, mask[0]:, :, :].squeeze(0) # L H C
# get_attn_mask(q1, k1, 1)
# get_attn_mask(q2, k2, 2)
# # vis
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(B, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
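# # Added note (not in the original file): the batch is packed into a single sequence here, which is
# # why q is viewed as (1, B*N, H, D) and the packed condition's kv as (1, total_cond_tokens, 2, H, D);
# # when mask (a list of caption lengths) is given, BlockDiagonalMask.from_seqlens([N] * B, mask)
# # restricts attention so that the i-th block of N image tokens only attends to the i-th caption.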
# class MaskedMultiHeadCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(MaskedMultiHeadCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # query: image tokens (x); key/value: condition tokens (cond); mask: [B, L] binary mask, 1 for valid condition tokens
# B, S, C = x.shape
# L = cond.shape[1]
# q = self.q_linear(x).view(B, S, self.num_heads, self.head_dim)
# kv = self.kv_linear(cond).view(B, L, 2, self.num_heads, self.head_dim)
# k, v = kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# attn_bias = mask.unsqueeze(1).unsqueeze(1).repeat(1, self.num_heads, S, 1).to(q.dtype) # B H S L
# exp = -1e9
# attn_bias[attn_bias==0] = exp
# attn_bias[attn_bias==1] = 0
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(B, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class LongShortMultiHeadCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(LongShortMultiHeadCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # query: image tokens (x); key/value: condition tokens (cond); mask: list of valid condition-token counts per sample
# B, N, C = x.shape
# M = cond.shape[1]
# q = self.q_linear(x).view(B, N, self.num_heads, self.head_dim)
# kv = self.kv_linear(cond).view(B, M, 2, self.num_heads, self.head_dim)
# k, v = kv.unbind(2)
# attn_bias = None
# if mask is not None:
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# x = x.view(B, N, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class MultiHeadV2TCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(MultiHeadV2TCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # query: text tokens (x); key/value: image tokens (cond); mask: list of valid text-token counts per sample
# B, N, C = cond.shape
# q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
# kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
# k, v = kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [N] * B)
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(B, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class MultiHeadT2VCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(MultiHeadT2VCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # query: video tokens (x); key/value: condition tokens (cond); mask: list of valid condition-token counts per sample
# #ipdb.set_trace()
# B, T, N, C = x.shape
# x = rearrange(x, 'B T N C -> (B T) N C')
# q = self.q_linear(x)
# q = rearrange(q, '(B T) N C -> B T N C', T=T)
# q = q.view(1, -1, self.num_heads, self.head_dim) # 1(B T N) H C
# kv = self.kv_linear(cond)
# kv = kv.view(1, -1, 2, self.num_heads, self.head_dim) # 1 N 2 H C
# k, v = kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# #mask = [m for m in mask for _ in range(T)]
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * (B*T), mask)
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(B, T, N, C)
# x = rearrange(x, 'B T N C -> (B T) N C')
# x = self.proj(x)
# x = self.proj_drop(x)
# x = rearrange(x, '(B T) N C -> B T N C', T=T)
# return x
# class FormerMultiHeadV2TCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(FormerMultiHeadV2TCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # x: text tokens; cond: img tokens; mask: list of valid text-token counts per sample
# #ipdb.set_trace()
# _, N, C = x.shape # 1 N C
# B, T, _, _ = cond.shape
# cond = rearrange(cond, 'B T N C -> (B T) N C')
# q = self.q_linear(x)
# q = q.view(1, -1, self.num_heads, self.head_dim) # 1 N H C
# kv = self.kv_linear(cond)
# kv = rearrange(kv, '(B T) N C -> B T N C', B=B)
# M = kv.shape[2] # M = H * W
# former_frame_index = torch.arange(T) - 1
# former_frame_index[0] = 0
# #ipdb.set_trace()
# former_kv = kv[:, former_frame_index]
# former_kv = former_kv.view(1, -1, 2, self.num_heads, self.head_dim) # 1(B T N) 2 H C
# former_k, former_v = former_kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# #mask = [m for m in mask for _ in range(T)]
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [M] * (B*T))
# x = xformers.ops.memory_efficient_attention(q, former_k, former_v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(1, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class LatterMultiHeadV2TCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(LatterMultiHeadV2TCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # x: text tokens; cond: img tokens; mask: list of valid text-token counts per sample
# #ipdb.set_trace()
# _, N, C = x.shape # 1 N C
# B, T, _, _ = cond.shape
# cond = rearrange(cond, 'B T N C -> (B T) N C')
# q = self.q_linear(x)
# q = q.view(1, -1, self.num_heads, self.head_dim) # 1 N H C
# kv = self.kv_linear(cond)
# kv = rearrange(kv, '(B T) N C -> B T N C', T=T)
# M = kv.shape[2] # M = H * W
# latter_frame_index = torch.arange(T) + 1
# latter_frame_index[-1] = T - 1
# #ipdb.set_trace()
# latter_kv = kv[:, latter_frame_index]
# latter_kv = latter_kv.view(1, -1, 2, self.num_heads, self.head_dim) # 1(B T N) 2 H C
# latter_k, latter_v = latter_kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# # mask = [m for m in mask for _ in range(T)]
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [M] * (B*T))
# x = xformers.ops.memory_efficient_attention(q, latter_k, latter_v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(1, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
# def __init__(
# self,
# d_model,
# num_heads,
# attn_drop=0.0,
# proj_drop=0.0,
# ):
# super().__init__(d_model=d_model, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
# def forward(self, x, cond, mask=None):
# # query: image tokens (x); key/value: condition tokens (cond); mask: list of valid condition-token counts per sample
# sp_group = get_sequence_parallel_group()
# sp_size = dist.get_world_size(sp_group)
# B, SUB_N, C = x.shape
# N = SUB_N * sp_size
# # shape:
# # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
# q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
# kv = self.kv_linear(cond).view(B, -1, 2, self.num_heads, self.head_dim)
# k, v = kv.unbind(2)
# # apply all_to_all to gather sequence and split attention heads
# q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1)
# k = split_forward_gather_backward(k, get_sequence_parallel_group(), dim=2, grad_scale="down")
# v = split_forward_gather_backward(v, get_sequence_parallel_group(), dim=2, grad_scale="down")
# q = q.view(1, -1, self.num_heads // sp_size, self.head_dim)
# k = k.view(1, -1, self.num_heads // sp_size, self.head_dim)
# v = v.view(1, -1, self.num_heads // sp_size, self.head_dim)
# # compute attention
# attn_bias = None
# if mask is not None:
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# # apply all to all to gather back attention heads and scatter sequence
# x = x.view(B, -1, self.num_heads // sp_size, self.head_dim)
# x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)
# # apply output projection
# x = x.view(B, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class FinalLayer(nn.Module):
# """
# The final layer of DiT.
# """
# def __init__(self, hidden_size, num_patch, out_channels):
# super().__init__()
# self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
# self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
# def forward(self, x, c):
# shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
# x = modulate(self.norm_final, x, shift, scale)
# x = self.linear(x)
# return x
# class T2IFinalLayer(nn.Module):
# """
# The final layer of PixArt.
# """
# def __init__(self, hidden_size, num_patch, out_channels):
# super().__init__()
# self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
# self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
# self.out_channels = out_channels
# def forward(self, x, t):
# shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
# x = t2i_modulate(self.norm_final(x), shift, scale)
# x = self.linear(x)
# return x
# # ===============================================
# # Embedding Layers for Timesteps and Class Labels
# # ===============================================
# class TimestepEmbedder(nn.Module):
# """
# Embeds scalar timesteps into vector representations.
# """
# def __init__(self, hidden_size, frequency_embedding_size=256):
# super().__init__()
# self.mlp = nn.Sequential(
# nn.Linear(frequency_embedding_size, hidden_size, bias=True),
# nn.SiLU(),
# nn.Linear(hidden_size, hidden_size, bias=True),
# )
# self.frequency_embedding_size = frequency_embedding_size
# @staticmethod
# def timestep_embedding(t, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
# :param t: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an (N, D) Tensor of positional embeddings.
# """
# # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
# freqs = freqs.to(device=t.device)
# args = t[:, None].float() * freqs[None]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding
# def forward(self, t, dtype):
# t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
# if t_freq.dtype != dtype:
# t_freq = t_freq.to(dtype)
# t_emb = self.mlp(t_freq)
# return t_emb
# class LabelEmbedder(nn.Module):
# """
# Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
# """
# def __init__(self, num_classes, hidden_size, dropout_prob):
# super().__init__()
# use_cfg_embedding = dropout_prob > 0
# self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
# self.num_classes = num_classes
# self.dropout_prob = dropout_prob
# def token_drop(self, labels, force_drop_ids=None):
# """
# Drops labels to enable classifier-free guidance.
# """
# if force_drop_ids is None:
# drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
# else:
# drop_ids = force_drop_ids == 1
# labels = torch.where(drop_ids, self.num_classes, labels)
# return labels
# def forward(self, labels, train, force_drop_ids=None):
# use_dropout = self.dropout_prob > 0
# if (train and use_dropout) or (force_drop_ids is not None):
# labels = self.token_drop(labels, force_drop_ids)
# return self.embedding_table(labels)
# class SizeEmbedder(TimestepEmbedder):
# """
# Embeds scalar size conditions (e.g. resolution or aspect ratio) into vector representations.
# """
# def __init__(self, hidden_size, frequency_embedding_size=256):
# super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
# self.mlp = nn.Sequential(
# nn.Linear(frequency_embedding_size, hidden_size, bias=True),
# nn.SiLU(),
# nn.Linear(hidden_size, hidden_size, bias=True),
# )
# self.frequency_embedding_size = frequency_embedding_size
# self.outdim = hidden_size
# def forward(self, s, bs):
# if s.ndim == 1:
# s = s[:, None]
# assert s.ndim == 2
# if s.shape[0] != bs:
# s = s.repeat(bs // s.shape[0], 1)
# assert s.shape[0] == bs
# b, dims = s.shape[0], s.shape[1]
# s = rearrange(s, "b d -> (b d)")
# s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
# s_emb = self.mlp(s_freq)
# s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
# return s_emb
# @property
# def dtype(self):
# return next(self.parameters()).dtype
# class CaptionEmbedder(nn.Module):
# """
# Projects caption embeddings into the model width. Also handles caption dropout for classifier-free guidance.
# """
# def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
# super().__init__()
# self.y_proj = Mlp(
# in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
# )
# self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
# self.uncond_prob = uncond_prob
# def token_drop(self, caption, force_drop_ids=None):
# """
# Drops labels to enable classifier-free guidance.
# """
# if force_drop_ids is None:
# drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
# else:
# drop_ids = force_drop_ids == 1
# caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
# return caption
# def forward(self, caption, train, force_drop_ids=None):
# if train:
# assert caption.shape[2:] == self.y_embedding.shape
# use_dropout = self.uncond_prob > 0
# if (train and use_dropout) or (force_drop_ids is not None):
# caption = self.token_drop(caption, force_drop_ids)
# caption = self.y_proj(caption)
# return caption
# # ===============================================
# # Sine/Cosine Positional Embedding Functions
# # ===============================================
# # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
# """
# grid_size: int of the grid height and width
# return:
# pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
# """
# if not isinstance(grid_size, tuple):
# grid_size = (grid_size, grid_size)
# grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
# grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
# if base_size is not None:
# grid_h *= base_size / grid_size[0]
# grid_w *= base_size / grid_size[1]
# grid = np.meshgrid(grid_w, grid_h) # here w goes first
# grid = np.stack(grid, axis=0)
# grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
# pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
# if cls_token and extra_tokens > 0:
# pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
# return pos_embed
# def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
# assert embed_dim % 2 == 0
# # use half of dimensions to encode grid_h
# emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
# emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
# emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
# return emb
# def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
# pos = np.arange(0, length)[..., None] / scale
# return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
# def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
# """
# embed_dim: output dimension for each position
# pos: a list of positions to be encoded: size (M,)
# out: (M, D)
# """
# assert embed_dim % 2 == 0
# omega = np.arange(embed_dim // 2, dtype=np.float64)
# omega /= embed_dim / 2.0
# omega = 1.0 / 10000**omega # (D/2,)
# pos = pos.reshape(-1) # (M,)
# out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
# emb_sin = np.sin(out) # (M, D/2)
# emb_cos = np.cos(out) # (M, D/2)
# emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
# return emb
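# # Hedged numeric example (added, not in the original file): for embed_dim=4 and a single position
# # pos=1, omega = 1.0 / 10000 ** (np.arange(2) / 2.0) = [1.0, 0.01], out = [1.0, 0.01], and the
# # embedding is [sin(1.0), sin(0.01), cos(1.0), cos(0.01)] ≈ [0.8415, 0.0100, 0.5403, 1.0000].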
# Code below is copied from timm 0.3.2.
import torch
import torch.nn as nn
import math
import warnings
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
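# Hedged usage sketch (added, not part of the original file): drop_path keeps each sample with
# probability 1 - drop_prob and rescales the survivors by 1 / keep_prob, so the expected value of the
# output matches the input during training.
# >>> x = torch.ones(4, 3)
# >>> out = drop_path(x, drop_prob=0.5, training=True)
# >>> # each row of `out` is either all zeros or all 2.0 (= 1 / keep_prob)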
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
from .pixart import PixArt, PixArt_XL_2
# Adapted from PixArt
#
# Copyright (C) 2023 PixArt-alpha/PixArt-alpha
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# DiT: https://github.com/facebookresearch/DiT/tree/main
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
# from .builder import MODELS
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
MultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
SizeEmbedder,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
class PixArtBlock(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(
self,
hidden_size,
num_heads,
mlp_ratio=4.0,
drop_path=0.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
):
super().__init__()
self.hidden_size = hidden_size
self.enable_flashattn = enable_flashattn
self._enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.attn_cls = SeqParallelAttention
self.mha_cls = SeqParallelMultiHeadCrossAttention
else:
self.attn_cls = Attention
self.mha_cls = MultiHeadCrossAttention
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.attn = self.attn_cls(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=enable_flashattn,
)
self.cross_attn = self.mha_cls(hidden_size, num_heads)
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
def forward(self, x, y, t, mask=None):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + t.reshape(B, 6, -1)
).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
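# Hedged shape note (added for clarity, not in the original file): t arrives from PixArt.t_block as
# [B, 6*C]; it is reshaped to [B, 6, C], added to the learned scale_shift_table, and split into the
# shift/scale/gate triplets of the attention and MLP branches, e.g.
# >>> t = torch.randn(2, 6 * 1152)
# >>> table = torch.randn(6, 1152) / 1152 ** 0.5
# >>> chunks = (table[None] + t.reshape(2, 6, -1)).chunk(6, dim=1)   # six tensors of shape [2, 1, 1152]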
@MODELS.register_module()
class PixArt(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=(1, 32, 32),
in_channels=4,
patch_size=(1, 2, 2),
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
pred_sigma=True,
drop_path: float = 0.0,
no_temporal_pos_emb=False,
caption_channels=4096,
model_max_length=120,
dtype=torch.float32,
freeze=None,
space_scale=1.0,
time_scale=1.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
):
super().__init__()
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.input_size = input_size
num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
self.num_patches = num_patches
self.num_temporal = input_size[0] // patch_size[0]
self.num_spatial = num_patches // self.num_temporal
self.base_size = int(np.sqrt(self.num_spatial))
self.num_heads = num_heads
self.dtype = dtype
self.no_temporal_pos_emb = no_temporal_pos_emb
self.depth = depth
self.mlp_ratio = mlp_ratio
self.enable_flashattn = enable_flashattn
self.enable_layernorm_kernel = enable_layernorm_kernel
self.space_scale = space_scale
self.time_scale = time_scale
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
uncond_prob=class_dropout_prob,
act_layer=approx_gelu,
token_num=model_max_length,
)
self.register_buffer("pos_embed", self.get_spatial_pos_embed())
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList(
[
PixArtBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
drop_path=drop_path[i],
enable_flashattn=enable_flashattn,
enable_layernorm_kernel=enable_layernorm_kernel,
)
for i in range(depth)
]
)
self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
self.initialize_weights()
if freeze is not None:
assert freeze in ["text"]
if freeze == "text":
self.freeze_text()
def forward(self, x, timestep, y, mask=None):
"""
Forward pass of PixArt.
x: (N, C, T, H, W) tensor of latent inputs (T=1 for images)
timestep: (N,) tensor of diffusion timesteps
y: (N, 1, model_max_length, caption_channels) tensor of caption embeddings
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # (B, N, D)
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
x = x + self.pos_embed
if not self.no_temporal_pos_emb:
x = rearrange(x, "b t s d -> b s t d")
x = x + self.pos_embed_temporal
x = rearrange(x, "b s t d -> b (t s) d")
else:
x = rearrange(x, "b t s d -> b (t s) d")
t = self.t_embedder(timestep, dtype=x.dtype) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, 1, L, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# blocks
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, y_lens)
# final process
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def unpatchify(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, grid_size=None):
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
base_size=self.base_size,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def freeze_text(self):
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@MODELS.register_module()
class PixArtMS(PixArt):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3"
self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
def forward(self, x, timestep, y, mask=None, data_info=None):
"""
Forward pass of PixArt.
x: (N, C, T, H, W) tensor of latent inputs (T=1 for images)
timestep: (N,) tensor of diffusion timesteps
y: (N, 1, model_max_length, caption_channels) tensor of caption embeddings
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
c_size = data_info["hw"]
ar = data_info["ar"]
pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)
# embedding
x = self.x_embedder(x) # (B, N, D)
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
x = x + pos_embed.to(x.device)
if not self.no_temporal_pos_emb:
x = rearrange(x, "b t s d -> b s t d")
x = x + self.pos_embed_temporal
x = rearrange(x, "b s t d -> b (t s) d")
else:
x = rearrange(x, "b t s d -> b (t s) d")
t = self.t_embedder(timestep, dtype=x.dtype) # (N, D)
B = x.shape[0]
csize = self.csize_embedder(c_size, B)
ar = self.ar_embedder(ar, B)
t = t + torch.cat([csize, ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, 1, L, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# blocks
for block in self.blocks:
x = block(x, y, t0, y_lens)
# final process
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
@MODELS.register_module("PixArt-XL/2")
def PixArt_XL_2(from_pretrained=None, **kwargs):
model = PixArt(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
@MODELS.register_module("PixArtMS-XL/2")
def PixArtMS_XL_2(from_pretrained=None, **kwargs):
model = PixArtMS(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
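# Hedged usage sketch (assumption, not from the original repo): building the registered "PixArt-XL/2"
# variant and running a single denoising step on random latents. xformers is required for the
# cross-attention, so this is expected to run on a GPU.
# >>> model = PixArt_XL_2().cuda().eval()
# >>> x = torch.randn(2, 4, 1, 32, 32, device="cuda")   # latents [B, C, T, H, W]
# >>> t = torch.randint(0, 1000, (2,), device="cuda")   # diffusion timesteps [B]
# >>> y = torch.randn(2, 1, 120, 4096, device="cuda")   # caption embeddings [B, 1, model_max_length, caption_channels]
# >>> with torch.no_grad():
# ...     out = model(x, t, y)                          # [B, 8, 1, 32, 32]; pred_sigma doubles the channels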
from .stdit import STDiT
#from .stdit_ftf import STDiT_FTF
from .stdit_qknorm_rope import STDiT_QKNorm_RoPE
#from .stdit_denseadd import STDiT_DenseAdd
#from .stdit_densecat import STDiT_DenseCat
#from .stdit_it2v import STDiT_IT2V
#from .stdit_densewmean import STDiT_DenseWmean
#from .stdit_longshort import STDiT_LS
#from .stdit_ttca import STDiT_TTCA
#from .stdit_densecat_norm import STDiT_DenseCatNorm
#from .ustdit import USTDiT
#from .stdit_densesecat import STDiT_DenseSECat
#from .stdit_densecat_mmdit import STDiT_DenseMM
from .stdit_mmdit import STDiT_MMDiT
from .stdit_freq import STDiT_freq
#from .stdit_densecat_multipos import STDiT_DenseCatMpos
#from .stdit_mmdit_nocross import STDiT_MMDiT_Nocross
#from .stdit_mmdit_tempsplit import STDiT_MMDiT_TempSplit
#from .stdit_mmdit_qk import STDiT_MMDiTQK
#from .stdit_mmdit_small import STDiT_MMDiT_Small
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
MultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
# import ipdb
class STDiTBlock(nn.Module):
def __init__(
self,
hidden_size,
num_heads,
d_s=None,
d_t=None,
mlp_ratio=4.0,
drop_path=0.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
):
super().__init__()
self.hidden_size = hidden_size
self.enable_flashattn = enable_flashattn
self._enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.attn_cls = SeqParallelAttention
self.mha_cls = SeqParallelMultiHeadCrossAttention
else: # here
self.attn_cls = Attention
self.mha_cls = MultiHeadCrossAttention
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.attn = self.attn_cls(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=enable_flashattn,
)
self.cross_attn = self.mha_cls(hidden_size, num_heads)
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
# temporal attention
self.d_s = d_s
self.d_t = d_t
if self._enable_sequence_parallelism:
sp_size = dist.get_world_size(get_sequence_parallel_group())
# make sure d_t is divisible by sp_size
assert d_t % sp_size == 0
self.d_t = d_t // sp_size
self.attn_temp = self.attn_cls(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=self.enable_flashattn,
)
def forward(self, x, y, t, mask=None, tpe=None):
B, N, C = x.shape
# ipdb.set_trace()
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + t.reshape(B, 6, -1)
).chunk(6, dim=1)
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
# spatial branch
x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
x_s = self.attn(x_s)
x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
x = x + self.drop_path(gate_msa * x_s)
# temporal branch
x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
if tpe is not None:
x_t = x_t + tpe
x_t = self.attn_temp(x_t)
x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s)
x = x + self.drop_path(gate_msa * x_t)
# cross attn
x = x + self.cross_attn(x, y, mask)
# mlp
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
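# Hedged shape walkthrough (added for clarity, not in the original file): instead of one attention over
# all T*S tokens, the block factorizes it into two cheaper passes, both gated by gate_msa:
#   spatial:  [B, T*S, C] -> [(B*T), S, C]   attend within each frame
#   temporal: [B, T*S, C] -> [(B*S), T, C]   attend across frames at each spatial location (tpe added here)
# followed by cross-attention with the text tokens and a gated MLP.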
@MODELS.register_module()
class STDiT(nn.Module):
def __init__(
self,
input_size=(1, 32, 32),
in_channels=4,
patch_size=(1, 2, 2),
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
pred_sigma=True,
drop_path=0.0,
no_temporal_pos_emb=False,
caption_channels=4096,
model_max_length=120,
dtype=torch.float32,
space_scale=1.0,
time_scale=1.0,
freeze=None,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
):
super().__init__()
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.input_size = input_size
num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
self.num_patches = num_patches
self.num_temporal = input_size[0] // patch_size[0]
self.num_spatial = num_patches // self.num_temporal
self.num_heads = num_heads
self.dtype = dtype
self.no_temporal_pos_emb = no_temporal_pos_emb
self.depth = depth
self.mlp_ratio = mlp_ratio
self.enable_flashattn = enable_flashattn
self.enable_layernorm_kernel = enable_layernorm_kernel
self.space_scale = space_scale
self.time_scale = time_scale
self.register_buffer("pos_embed", self.get_spatial_pos_embed())
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
uncond_prob=class_dropout_prob,
act_layer=approx_gelu,
token_num=model_max_length,
)
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
self.blocks = nn.ModuleList(
[
STDiTBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=self.mlp_ratio,
drop_path=drop_path[i],
enable_flashattn=self.enable_flashattn,
enable_layernorm_kernel=self.enable_layernorm_kernel,
enable_sequence_parallelism=enable_sequence_parallelism,
d_t=self.num_temporal,
d_s=self.num_spatial,
)
for i in range(self.depth)
]
)
self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
# init model
self.initialize_weights()
self.initialize_temporal()
if freeze is not None:
assert freeze in ["not_temporal", "text"]
if freeze == "not_temporal":
self.freeze_not_temporal()
elif freeze == "text":
self.freeze_text()
# sequence parallel related configs
self.enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.sp_rank = dist.get_rank(get_sequence_parallel_group())
else:
self.sp_rank = None
def forward(self, x, timestep, y, mask=None):
"""
Forward pass of STDiT.
Args:
x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
timestep (torch.Tensor): diffusion time steps; of shape [B]
y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
Returns:
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
if self.enable_sequence_parallelism:
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
t0 = self.t_block(t) # [B, C]
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# blocks
for i, block in enumerate(self.blocks):
if i == 0:
if self.enable_sequence_parallelism:
tpe = torch.chunk(
self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
)[self.sp_rank].contiguous()
else:
tpe = self.pos_embed_temporal
else:
tpe = None
x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
if self.enable_sequence_parallelism:
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
# x.shape: [B, N, C]
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def unpatchify(self, x):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
def unpatchify_old(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, grid_size=None):
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def freeze_not_temporal(self):
for n, p in self.named_parameters():
if "attn_temp" not in n:
p.requires_grad = False
def freeze_text(self):
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False
def initialize_temporal(self):
for block in self.blocks:
nn.init.constant_(block.attn_temp.proj.weight, 0)
nn.init.constant_(block.attn_temp.proj.bias, 0)
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@MODELS.register_module("STDiT-XL/2")
def STDiT_XL_2(from_pretrained=None, **kwargs):
model = STDiT(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
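# ================================================
# Usage sketch (illustrative only, not part of the model definition).
# A minimal, hypothetical example of one denoising forward pass with
# STDiT-XL/2 on random tensors; the latent size, batch size and token
# count are assumptions, and the forward relies on the Open-Sora
# dependency stack (e.g. xformers), so it is meant to run on GPU.
#
#     model = STDiT_XL_2(input_size=(16, 32, 32), in_channels=4).cuda().eval()
#     x = torch.randn(2, 4, 16, 32, 32, device="cuda")        # [B, C, T, H, W] video latent
#     timestep = torch.randint(0, 1000, (2,), device="cuda")  # [B] diffusion steps
#     y = torch.randn(2, 1, 120, 4096, device="cuda")         # [B, 1, N_token, C_caption]
#     mask = torch.ones(2, 120, dtype=torch.long, device="cuda")
#     with torch.no_grad():
#         out = model(x, timestep, y, mask=mask)              # [B, 2*C, T, H, W] (pred_sigma=True)
# ================================================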
import re
import torch
import torch.nn as nn
from copy import deepcopy
from torch import Tensor
from torch.nn import Module, Linear, init
from typing import Any, Mapping
import sys
sys.path.append("/home/test/Workspace/ruixie/Open-Sora")
from einops import rearrange
# from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt
from opensora.models.stdit.stdit import STDiTBlock, STDiT
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
MultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
)
from opensora.acceleration.checkpoint import auto_grad_checkpoint
# imports needed by the sequence-parallel branch in forward()
from opensora.acceleration.communications import split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
# The implementation of the ControlNet-Half architecture
# https://github.com/lllyasviel/ControlNet/discussions/188
class ControlT2IDitBlockHalf(Module):
def __init__(self, base_block: STDiTBlock, block_index: int) -> None:
super().__init__()
self.copied_block = deepcopy(base_block)
self.block_index = block_index
for p in self.copied_block.parameters():
p.requires_grad_(True)
self.copied_block.load_state_dict(base_block.state_dict())
self.copied_block.train()
self.hidden_size = hidden_size = base_block.hidden_size
if self.block_index == 0:
self.before_proj = Linear(hidden_size, hidden_size)
init.zeros_(self.before_proj.weight)
init.zeros_(self.before_proj.bias)
self.after_proj = Linear(hidden_size, hidden_size)
init.zeros_(self.after_proj.weight)
init.zeros_(self.after_proj.bias)
def forward(self, x, y, t0, y_lens, c, tpe):
if self.block_index == 0:
# the first block
c = self.before_proj(c)
c = self.copied_block(x + c, y, t0, y_lens, tpe)
c_skip = self.after_proj(c)
else:
# load from previous c and produce the c for skip connection
c = self.copied_block(c, y, t0, y_lens, tpe)
c_skip = self.after_proj(c)
return c, c_skip
# The implementation of ControlPixArtHalf net
class ControlPixArtHalf(Module):
# only support single res model
def __init__(self, base_model: STDiT, copy_blocks_num: int = 13) -> None:
super().__init__()
self.base_model = base_model.eval()
self.controlnet = []
self.copy_blocks_num = copy_blocks_num
self.total_blocks_num = len(base_model.blocks)
for p in self.base_model.parameters():
p.requires_grad_(False)
# Copy first copy_blocks_num block
for i in range(copy_blocks_num):
self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
self.controlnet = nn.ModuleList(self.controlnet)
def __getattr__(self, name: str) -> "Tensor | Module":
if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
return self.__dict__[name]
elif name in ['base_model', 'controlnet']:
return super().__getattr__(name)
else:
return getattr(self.base_model, name)
def forward_c(self, c):
# self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size
# pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype)
x = self.x_embedder(c) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
return x if c is not None else c
# def forward(self, x, t, c, **kwargs):
# return self.base_model(x, t, c=self.forward_c(c), **kwargs)
def forward(self, x, timestep, y, mask=None, x_mask=None, c=None):
# modify the original PixArtMS forward function
if c is not None:
c = c.to(self.dtype)
c = self.forward_c(c)
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
if self.enable_sequence_parallelism:
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
t0 = self.t_block(t) # [B, 6*C]
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# define the first layer
# y_ori = y
tpe = self.pos_embed_temporal
x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, tpe)
# define the rest layers
# update c
for index in range(1, self.copy_blocks_num + 1):
if index == 1:
c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, tpe)
else:
c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, None)
x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, None)
# update x
for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, None)
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
return super().load_state_dict(state_dict, strict)
else:
new_key = {}
for k in state_dict.keys():
new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
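# maps plain block keys onto the wrapped layout, e.g. "blocks.3.attn.qkv.weight" -> "blocks.3.base_block.attn.qkv.weight"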
for k, v in new_key.items():
if k != v:
print(f"replace {k} to {v}")
state_dict[v] = state_dict.pop(k)
return self.base_model.load_state_dict(state_dict, strict)
def unpatchify(self, x):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
def unpatchify_old(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, grid_size=None):
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def freeze_not_temporal(self):
for n, p in self.named_parameters():
if "attn_temp" not in n:
p.requires_grad = False
def freeze_text(self):
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False
def initialize_temporal(self):
for block in self.blocks:
nn.init.constant_(block.attn_temp.proj.weight, 0)
nn.init.constant_(block.attn_temp.proj.bias, 0)
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
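# ================================================
# Usage sketch (illustrative only). A minimal, hypothetical example of
# wrapping a frozen STDiT backbone with the ControlNet-Half branch; the
# sizes are assumptions, and the condition `c` passed to forward() must
# share the latent layout of x so that x_embedder can embed it.
#
#     base = STDiT(input_size=(16, 32, 32), in_channels=4, depth=28,
#                  hidden_size=1152, patch_size=(1, 2, 2), num_heads=16)
#     model = ControlPixArtHalf(base, copy_blocks_num=13)
#     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
#     frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
#     print(f"control branch: {trainable / 1e6:.1f}M trainable, {frozen / 1e6:.1f}M frozen")
# ================================================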
import re
import torch
import torch.nn as nn
from copy import deepcopy
from torch import Tensor
from torch.nn import Module, Linear, init
from typing import Any, Mapping
# import pywt
import torch.fft
import numpy as np
import sys
sys.path.append("/mnt/bn/videodataset/VSR/VSR")
from einops import rearrange
# from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt
from opensora.models.stdit.stdit import STDiTBlock, STDiT
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
MultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
)
from opensora.acceleration.checkpoint import auto_grad_checkpoint
# imports needed by the sequence-parallel branch in forward()
from opensora.acceleration.communications import split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
import torch.nn.functional as F
from einops import rearrange
# The implementation of the ControlNet-Half architecture
# https://github.com/lllyasviel/ControlNet/discussions/188
class ControlT2IDitBlockHalf(Module):
def __init__(self, base_block: STDiTBlock, block_index: int) -> None:
super().__init__()
self.copied_block = deepcopy(base_block)
self.block_index = block_index
for p in self.copied_block.parameters():
p.requires_grad_(True)
self.copied_block.load_state_dict(base_block.state_dict())
self.copied_block.train()
self.hidden_size = hidden_size = base_block.hidden_size
if self.block_index == 0:
self.before_proj = Linear(hidden_size, hidden_size)
init.zeros_(self.before_proj.weight)
init.zeros_(self.before_proj.bias)
self.after_proj = Linear(hidden_size, hidden_size)
init.zeros_(self.after_proj.weight)
init.zeros_(self.after_proj.bias)
def forward(self, x, y, t0, y_lens, c, tpe, hf_fea, lf_fea, temp_fea):
if self.block_index == 0:
# the first block
c = self.before_proj(c)
c = self.copied_block(x + c, y, t0, y_lens, tpe, hf_fea, lf_fea, temp_fea)
c_skip = self.after_proj(c)
else:
# load from previous c and produce the c for skip connection
c = self.copied_block(c, y, t0, y_lens, tpe, hf_fea, lf_fea, temp_fea)
c_skip = self.after_proj(c)
return c, c_skip
# The implementation of ControlPixArtHalf net
class ControlPixArtHalf(Module):
# only support single res model
def __init__(self, base_model: STDiT, copy_blocks_num: int = 13) -> None:
super().__init__()
self.base_model = base_model
self.controlnet = []
self.copy_blocks_num = copy_blocks_num
self.total_blocks_num = len(base_model.blocks)
for p in self.base_model.parameters():
p.requires_grad_(False)
# Copy first copy_blocks_num block
for i in range(copy_blocks_num):
self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
self.controlnet = nn.ModuleList(self.controlnet)
def __getattr__(self, name: str) -> "Tensor | Module":
if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
return self.__dict__[name]
elif name in ['base_model', 'controlnet']:
return super().__getattr__(name)
else:
return getattr(self.base_model, name)
def forward_c(self, c):
### Controlnet Input ###
x = self.x_embedder(c) # [B, N, 1152]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
return x
def forward_hf(self, hf_part):
### Controlnet Input ###
x = self.hf_embedder(hf_part) # [B, N, 1152]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
return x
def forward_lf(self, lf_part):
### Controlnet Input ###
x = self.lf_embedder(lf_part) # [B, N, 1152]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
return x
def forward(self, x, timestep, y, mask=None, c=None, lr=None):
# modify the original PixArtMS forward function
if c is not None:
c = c.to(self.dtype)
c = auto_grad_checkpoint(self.forward_c, c)
# generate spatial & temporal information
_, hf_part, lf_part = self.fdie.spatial_forward(lr)
hf_fea = auto_grad_checkpoint(self.forward_hf, hf_part)
lf_fea = auto_grad_checkpoint(self.forward_lf, lf_part)
temp_fea = self.fdie.temporal_forward(lf_fea)
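# hf_fea / lf_fea are the patch-embedded high- and low-frequency components of the
# LR input (with the spatial positional embedding added), and temp_fea is the temporal
# context derived from the low-frequency features; all three are fed to every block below.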
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
if self.enable_sequence_parallelism:
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
t0 = self.t_block(t) # [B, 6*C]
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# define the first layer
# y_ori = y
tpe = self.pos_embed_temporal
x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, tpe, hf_fea, lf_fea, temp_fea)
# define the rest layers
# update c
for index in range(1, self.copy_blocks_num + 1):
if index == 1:
c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, tpe, hf_fea, lf_fea, temp_fea)
else:
c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, None, hf_fea, lf_fea, temp_fea)
x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, None, hf_fea, lf_fea, temp_fea)
# update x
for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, None, hf_fea, lf_fea, temp_fea)
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
return super().load_state_dict(state_dict, strict)
else:
new_key = {}
for k in state_dict.keys():
new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
for k, v in new_key.items():
if k != v:
print(f"replace {k} to {v}")
state_dict[v] = state_dict.pop(k)
return self.base_model.load_state_dict(state_dict, strict)
def unpatchify(self, x):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
def unpatchify_old(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, grid_size=None):
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def freeze_not_temporal(self):
for n, p in self.named_parameters():
if "attn_temp" not in n:
p.requires_grad = False
def freeze_text(self):
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False
def initialize_temporal(self):
for block in self.blocks:
nn.init.constant_(block.attn_temp.proj.weight, 0)
nn.init.constant_(block.attn_temp.proj.bias, 0)
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
import re
import torch
import torch.nn as nn
from copy import deepcopy
from torch import Tensor
from torch.nn import Module, Linear, init
from typing import Any, Mapping
import sys
sys.path.append("/home/test/Workspace/ruixie/Open-Sora")
from einops import rearrange
# from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt
from opensora.models.stdit.stdit import STDiTBlock, STDiT
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
MultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
)
from opensora.acceleration.checkpoint import auto_grad_checkpoint
# imports needed by the sequence-parallel branch in forward()
from opensora.acceleration.communications import split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
# The implementation of the ControlNet-Half architecture
# https://github.com/lllyasviel/ControlNet/discussions/188
class ControlT2IDitBlockHalf(Module):
def __init__(self, base_block: STDiTBlock, block_index: int) -> None:
super().__init__()
self.copied_block = deepcopy(base_block)
self.block_index = block_index
for p in self.copied_block.parameters():
p.requires_grad_(True)
self.copied_block.load_state_dict(base_block.state_dict())
self.copied_block.train()
self.hidden_size = hidden_size = base_block.hidden_size
if self.block_index == 0:
self.before_proj = Linear(hidden_size, hidden_size)
init.zeros_(self.before_proj.weight)
init.zeros_(self.before_proj.bias)
self.after_proj = Linear(hidden_size, hidden_size)
init.zeros_(self.after_proj.weight)
init.zeros_(self.after_proj.bias)
def forward(self, x, y, t0, t_y, t0_tmep, t_y_tmep, mask, c, tpe):
if self.block_index == 0:
# the first block
c = self.before_proj(c)
c, y = self.copied_block(x + c, y, t0, t_y, t0_tmep, t_y_tmep, mask, tpe)
c_skip = self.after_proj(c)
else:
# load from previous c and produce the c for skip connection
c, y = self.copied_block(c, y, t0, t_y, t0_tmep, t_y_tmep, mask, tpe)
c_skip = self.after_proj(c)
return c, c_skip, y
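# Note: unlike the earlier variants, this copied block also propagates the updated text
# stream y, so it pairs with block implementations whose forward returns (x, y)
# (see the MMDiT-style STDiTBlock defined further below in this dump).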
# The implementation of ControlPixArtHalf net
class ControlPixArtHalf(Module):
# only support single res model
def __init__(self, base_model: STDiT, copy_blocks_num: int = 13) -> None:
super().__init__()
self.base_model = base_model.eval()
self.controlnet = []
self.copy_blocks_num = copy_blocks_num
self.total_blocks_num = len(base_model.blocks)
for p in self.base_model.parameters():
p.requires_grad_(False)
# Copy first copy_blocks_num block
for i in range(copy_blocks_num):
self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
self.controlnet = nn.ModuleList(self.controlnet)
def __getattr__(self, name: str) -> "Tensor | Module":
if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
return self.__dict__[name]
elif name in ['base_model', 'controlnet']:
return super().__getattr__(name)
else:
return getattr(self.base_model, name)
def forward_c(self, c):
# self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size
# pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype)
x = self.x_embedder(c) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
return x if c is not None else c
# def forward(self, x, t, c, **kwargs):
# return self.base_model(x, t, c=self.forward_c(c), **kwargs)
def forward(self, x, timestep, y, mask=None, x_mask=None, c=None):
# modify the original PixArtMS forward function
if c is not None:
# print("Process condition")
c = c.to(self.dtype)
c = self.forward_c(c)
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
if self.enable_sequence_parallelism:
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
# t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
# t_spc_mlp = self.t_block(t) # [B, 6*C]
# t_tmp_mlp = self.t_block_temp(t) # [B, 3*C]
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
t0 = self.t_block(t) # [B, 6*C]
t_y = self.t_block_y(t) # [B, 6*C]
t0_tmep = self.t_block_temp(t) # [B, 3*C]
t_y_tmep = self.t_block_y_temp(t) # [B, 3*C]
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# define the first layer
y_ori = y
tpe = self.pos_embed_temporal
x, y_x = auto_grad_checkpoint(self.base_model.blocks[0], x, y_ori, t0, t_y, t0_tmep, t_y_tmep, y_lens, tpe)
# define the rest layers
# update c
for index in range(1, self.copy_blocks_num + 1):
if index == 1:
c, c_skip, y_c = auto_grad_checkpoint(self.controlnet[index - 1], x, y_ori, t0, t_y, t0_tmep, t_y_tmep, y_lens, c, tpe)
else:
c, c_skip, y_c = auto_grad_checkpoint(self.controlnet[index - 1], x, y_c, t0, t_y, t0_tmep, t_y_tmep, y_lens, c, None)
x, y_x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y_x, t0, t_y, t0_tmep, t_y_tmep, y_lens, None)
# update x
for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
x, y_x = auto_grad_checkpoint(self.base_model.blocks[index], x, y_x, t0, t_y, t0_tmep, t_y_tmep, y_lens, None)
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
return super().load_state_dict(state_dict, strict)
else:
new_key = {}
for k in state_dict.keys():
new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
for k, v in new_key.items():
if k != v:
print(f"replace {k} to {v}")
state_dict[v] = state_dict.pop(k)
return self.base_model.load_state_dict(state_dict, strict)
def unpatchify(self, x):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
def unpatchify_old(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, grid_size=None):
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def freeze_not_temporal(self):
for n, p in self.named_parameters():
if "attn_temp" not in n:
p.requires_grad = False
def freeze_text(self):
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False
def initialize_temporal(self):
for block in self.blocks:
nn.init.constant_(block.attn_temp.proj.weight, 0)
nn.init.constant_(block.attn_temp.proj.bias, 0)
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
import re
import torch
import torch.nn as nn
from copy import deepcopy
from torch import Tensor
from torch.nn import Module, Linear, init
from typing import Any, Mapping
import sys
sys.path.append("/home/test/Workspace/ruixie/Open-Sora")
from einops import rearrange
# from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt
from opensora.models.stdit.stdit import STDiTBlock, STDiT
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
MultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
)
from opensora.acceleration.checkpoint import auto_grad_checkpoint
# imports needed by the sequence-parallel branch in forward()
from opensora.acceleration.communications import split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
# The implementation of the ControlNet-Half architecture
# https://github.com/lllyasviel/ControlNet/discussions/188
class ControlT2IDitBlockHalf(Module):
def __init__(self, base_block: STDiTBlock, block_index: int) -> None:
super().__init__()
self.copied_block = deepcopy(base_block)
self.block_index = block_index
for p in self.copied_block.parameters():
p.requires_grad_(True)
self.copied_block.load_state_dict(base_block.state_dict())
self.copied_block.train()
self.hidden_size = hidden_size = base_block.hidden_size
if self.block_index == 0:
self.before_proj = Linear(hidden_size, hidden_size)
init.zeros_(self.before_proj.weight)
init.zeros_(self.before_proj.bias)
self.after_proj = Linear(hidden_size, hidden_size)
init.zeros_(self.after_proj.weight)
init.zeros_(self.after_proj.bias)
def forward(self, x, y, t0, t0_tmep, y_lens, c, tpe):
if self.block_index == 0:
# the first block
c = self.before_proj(c)
c = self.copied_block(x + c, y, t0, t0_tmep, y_lens, tpe)
c_skip = self.after_proj(c)
else:
# load from previous c and produce the c for skip connection
c = self.copied_block(c, y, t0, t0_tmep, y_lens, tpe)
c_skip = self.after_proj(c)
return c, c_skip
# The implementation of ControlPixArtHalf net
class ControlPixArtHalf(Module):
# only support single res model
def __init__(self, base_model: STDiT, copy_blocks_num: int = 13) -> None:
super().__init__()
self.base_model = base_model.eval()
self.controlnet = []
self.copy_blocks_num = copy_blocks_num
self.total_blocks_num = len(base_model.blocks)
for p in self.base_model.parameters():
p.requires_grad_(False)
# Copy first copy_blocks_num block
for i in range(copy_blocks_num):
self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
self.controlnet = nn.ModuleList(self.controlnet)
def __getattr__(self, name: str) -> "Tensor | Module":
if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
return self.__dict__[name]
elif name in ['base_model', 'controlnet']:
return super().__getattr__(name)
else:
return getattr(self.base_model, name)
def forward_c(self, c):
# self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size
# pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype)
x = self.x_embedder(c) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
return x if c is not None else c
# def forward(self, x, t, c, **kwargs):
# return self.base_model(x, t, c=self.forward_c(c), **kwargs)
def forward(self, x, timestep, y, mask=None, x_mask=None, c=None):
# modify the original PixArtMS forward function
if c is not None:
c = c.to(self.dtype)
c = self.forward_c(c)
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
if self.enable_sequence_parallelism:
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
t0 = self.t_block(t) # [B, 6*C]
t0_temp = self.t_block_temp(t) # [B, 3*C]
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# define the first layer
# y_ori = y
tpe = self.pos_embed_temporal
x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, t0_temp, y_lens, tpe)
# define the rest layers
# update c
for index in range(1, self.copy_blocks_num + 1):
if index == 1:
c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, t0_temp, y_lens, c, tpe)
else:
c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, t0_temp, y_lens, c, None)
x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, t0_temp, y_lens, None)
# update x
for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, t0_temp, y_lens, None)
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
return super().load_state_dict(state_dict, strict)
else:
new_key = {}
for k in state_dict.keys():
new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
for k, v in new_key.items():
if k != v:
print(f"replace {k} to {v}")
state_dict[v] = state_dict.pop(k)
return self.base_model.load_state_dict(state_dict, strict)
def unpatchify(self, x):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
def unpatchify_old(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, grid_size=None):
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def freeze_not_temporal(self):
for n, p in self.named_parameters():
if "attn_temp" not in n:
p.requires_grad = False
def freeze_text(self):
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False
def initialize_temporal(self):
for block in self.blocks:
nn.init.constant_(block.attn_temp.proj.weight, 0)
nn.init.constant_(block.attn_temp.proj.bias, 0)
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
MultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
SpatialFrequencyBlcok,
TemporalFrequencyBlock,
Encoder_3D,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
from opensora.models.vsr.sfr_lftg import SpatialFeatureRefiner, LFTemporalGuider
from opensora.models.vsr.fdie_arch import FrequencyDecoupledInfoExtractor
class STDiTBlock(nn.Module):
def __init__(
self,
hidden_size,
num_heads,
d_s=None,
d_t=None,
mlp_ratio=4.0,
drop_path=0.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
):
super().__init__()
self.hidden_size = hidden_size
self.enable_flashattn = enable_flashattn
self._enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.attn_cls = SeqParallelAttention
self.mha_cls = SeqParallelMultiHeadCrossAttention
else: # here
self.attn_cls = Attention
self.mha_cls = MultiHeadCrossAttention
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.attn = self.attn_cls(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=enable_flashattn,
)
self.cross_attn = self.mha_cls(hidden_size, num_heads)
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
# temporal attention
self.d_s = d_s
self.d_t = d_t
if self._enable_sequence_parallelism:
sp_size = dist.get_world_size(get_sequence_parallel_group())
# make sure d_t is divisible by sp_size
assert d_t % sp_size == 0
self.d_t = d_t // sp_size
self.attn_temp = self.attn_cls(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=self.enable_flashattn,
)
# frequency block for spatial & temporal (New)
self.sfr = SpatialFeatureRefiner(hidden_channels=hidden_size)
self.lftg = LFTemporalGuider(d_model=hidden_size, num_heads=num_heads)
def forward(self, x, y, t, mask=None, tpe=None, hf_fea=None, lf_fea=None, temp_fea=None):
B, _, _ = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + t.reshape(B, 6, -1)
).chunk(6, dim=1)
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
# spatial feature refiner (New)
x_s = self.sfr(hf_fea, lf_fea, x_m)
# spatial branch
x_s = rearrange(x_s, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
x_s = self.attn(x_s)
x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
x = x + self.drop_path(gate_msa * x_s)
# LF temporal guider (New)
x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
temp_fea = rearrange(temp_fea, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s) # temporal positional embedding is already added inside the FDIE
if tpe is not None:
x_t = x_t + tpe
x_t = self.lftg(x_t, temp_fea)
# temporal branch
x_t = self.attn_temp(x_t)
x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s)
x = x + self.drop_path(gate_msa * x_t)
# cross attn
x = x + self.cross_attn(x, y, mask)
# mlp
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
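# In this frequency-aware block, the SpatialFeatureRefiner (sfr) fuses the high- and
# low-frequency features into the modulated tokens before spatial attention, and the
# LFTemporalGuider (lftg) injects low-frequency temporal context before temporal
# attention; the rest of the layout follows the original STDiTBlock.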
@MODELS.register_module()
class STDiT_freq(nn.Module):
def __init__(
self,
input_size=(1, 32, 32),
in_channels=4,
patch_size=(1, 2, 2),
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
pred_sigma=True,
drop_path=0.0,
no_temporal_pos_emb=False,
caption_channels=4096,
model_max_length=120,
dtype=torch.float32,
space_scale=1.0,
time_scale=1.0,
freeze=None,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
):
super().__init__()
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.input_size = input_size
num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
self.num_patches = num_patches
self.num_temporal = input_size[0] // patch_size[0]
self.num_spatial = num_patches // self.num_temporal
self.num_heads = num_heads
self.dtype = dtype
self.no_temporal_pos_emb = no_temporal_pos_emb
self.depth = depth
self.mlp_ratio = mlp_ratio
self.enable_flashattn = enable_flashattn
self.enable_layernorm_kernel = enable_layernorm_kernel
self.space_scale = space_scale
self.time_scale = time_scale
self.register_buffer("pos_embed", self.get_spatial_pos_embed())
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
uncond_prob=class_dropout_prob,
act_layer=approx_gelu,
token_num=model_max_length,
)
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
self.blocks = nn.ModuleList(
[
STDiTBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=self.mlp_ratio,
drop_path=drop_path[i],
enable_flashattn=self.enable_flashattn,
enable_layernorm_kernel=self.enable_layernorm_kernel,
enable_sequence_parallelism=enable_sequence_parallelism,
d_t=self.num_temporal,
d_s=self.num_spatial,
)
for i in range(self.depth)
]
)
self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
# frequency-decoupled information extractor
self.fdie = FrequencyDecoupledInfoExtractor(in_channels=3, hidden_channels=64)
# high-frequency & low-frequency embedder
self.hf_embedder = PatchEmbed3D(patch_size=(1, 16, 16), in_chans=3, embed_dim=hidden_size)
self.lf_embedder = PatchEmbed3D(patch_size=(1, 16, 16), in_chans=3, embed_dim=hidden_size)
# init model
self.initialize_weights()
self.initialize_temporal()
# sequence parallel related configs
self.enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.sp_rank = dist.get_rank(get_sequence_parallel_group())
else:
self.sp_rank = None
def forward(self, x, timestep, y, mask=None):
"""
Forward pass of STDiT.
Args:
x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
timestep (torch.Tensor): diffusion time steps; of shape [B]
y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
Returns:
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
if self.enable_sequence_parallelism:
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
t0 = self.t_block(t) # [B, 6*C]
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# blocks
for i, block in enumerate(self.blocks):
if i == 0:
if self.enable_sequence_parallelism:
tpe = torch.chunk(
self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
)[self.sp_rank].contiguous()
else:
tpe = self.pos_embed_temporal
else:
tpe = None
x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
if self.enable_sequence_parallelism:
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
# x.shape: [B, N, C]
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def unpatchify(self, x):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
def unpatchify_old(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, grid_size=None):
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
)
# ensure pos_embed is a NumPy ndarray before converting it to a tensor
pos_embed = np.array(pos_embed)
pos_embed = torch.tensor(pos_embed, dtype=torch.float32).unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def freeze_not_temporal(self):
for n, p in self.named_parameters():
if "attn_temp" not in n:
p.requires_grad = False
def freeze_text(self):
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False
def initialize_temporal(self):
for block in self.blocks:
nn.init.constant_(block.attn_temp.proj.weight, 0)
nn.init.constant_(block.attn_temp.proj.bias, 0)
# nn.init.constant_(block.temporal_freq.proj.weight, 0)
# nn.init.constant_(block.temporal_freq.proj.bias, 0)
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
e = self.lf_embedder.proj.weight.data
nn.init.xavier_uniform_(e.view([e.shape[0], -1]))
r = self.hf_embedder.proj.weight.data
nn.init.xavier_uniform_(r.view([r.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@MODELS.register_module("STDiT-freq-XL/2")
def STDiT_freq_XL_2(from_pretrained=None, **kwargs):
model = STDiT_freq(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
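# ================================================
# Usage sketch (illustrative only). A minimal, hypothetical example of
# building the frequency-aware backbone. Note that STDiT_freq.forward
# itself does not compute hf_fea / lf_fea / temp_fea; in this code base
# they are produced by a wrapper (see the ControlPixArtHalf variant that
# calls self.fdie and the hf/lf embedders) and handed to each block.
#
#     backbone = STDiT_freq_XL_2(input_size=(16, 32, 32), in_channels=4)
#     n_params = sum(p.numel() for p in backbone.parameters())
#     print(f"STDiT-freq-XL/2 parameters: {n_params / 1e6:.1f}M")
# ================================================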
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.models.layers.blocks import (
Attention,
Attention_QKNorm_RoPE,
MaskedSelfAttention,
CaptionEmbedder,
MultiHeadCrossAttention,
MaskedMultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
# import ipdb
from opensora.models.layers.timm_uvit import trunc_normal_
class STDiTBlock(nn.Module):
def __init__(
self,
hidden_size,
num_heads,
d_s=None,
d_t=None,
mlp_ratio=4.0,
drop_path=0.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
qk_norm=False,
):
super().__init__()
self.hidden_size = hidden_size
self.enable_flashattn = enable_flashattn
self._enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.attn_cls = SeqParallelAttention
self.mha_cls = SeqParallelMultiHeadCrossAttention
else: # here
self.self_masked_attn = MaskedSelfAttention
self.attn_cls = Attention_QKNorm_RoPE
self.mha_cls = MaskedMultiHeadCrossAttention
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.norm1_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.norm2_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.attn = self.self_masked_attn(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=enable_flashattn,
qk_norm=qk_norm,
)
self.cross_attn = self.mha_cls(hidden_size, num_heads)
self.norm3 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.norm3_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
)
self.mlp_y = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
self.scale_shift_table_y = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
self.scale_shift_table_temp = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)
self.scale_shift_table_y_temp = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)
# temporal attention
self.d_s = d_s
self.d_t = d_t
if self._enable_sequence_parallelism:
sp_size = dist.get_world_size(get_sequence_parallel_group())
# make sure d_t is divisible by sp_size
assert d_t % sp_size == 0
self.d_t = d_t // sp_size
self.attn_temp = self.attn_cls(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=self.enable_flashattn,
qk_norm=qk_norm,
)
def forward(self, x, y, t, t_y, t_tmep, t_y_tmep, mask=None, tpe=None):
B, N, C = x.shape
L = y.shape[2] # y: B T L C, mask: B T L
x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
x_mask = torch.ones(x.shape[:3], device=x.device, dtype=x.dtype) # B T S
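        # adaLN-single style modulation: each 6-entry table yields (shift, scale, gate)
        # for the joint spatial attention and for the MLP, while each 3-entry *_temp
        # table yields (shift, scale, gate) for the temporal attention branch,
        # separately for the video (x) and text (y) token streams.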
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + t.reshape(B, 6, -1)
).chunk(6, dim=1)
shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = (
self.scale_shift_table_y[None] + t_y.reshape(B, 6, -1)
).chunk(6, dim=1)
shift_msa_temp, scale_msa_temp, gate_msa_temp = (
            self.scale_shift_table_temp[None] + t_temp.reshape(B, 3, -1)
).chunk(3, dim=1)
shift_msa_y_temp, scale_msa_y_temp, gate_msa_y_temp = (
            self.scale_shift_table_y_temp[None] + t_y_temp.reshape(B, 3, -1)
).chunk(3, dim=1)
x = rearrange(x, "B T S C -> B (T S) C")
y = rearrange(y, "B T L C -> B (T L) C")
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
y_m = t2i_modulate(self.norm1_y(y), shift_msa_y, scale_msa_y)
x_m = rearrange(x_m, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
y_m = rearrange(y_m, "B (T L) C -> B T L C", T=self.d_t, L=L)
xy_m = torch.cat([x_m, y_m], dim=2)
xy_mask = torch.cat([x_mask, mask], dim=2)
xy_mask = rearrange(xy_mask, "B T N -> (B T) N")
# spatial branch
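        # MM-DiT-style joint attention: within each frame, the S video tokens and the
        # L caption tokens are concatenated and attend to one another; xy_mask hides
        # the padded caption tokens.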
xy_s = rearrange(xy_m, "B T N C -> (B T) N C")
xy_s = self.attn(xy_s, xy_mask)
xy_s = rearrange(xy_s, "(B T) N C -> B T N C", B=B, T=self.d_t)
x_s = xy_s[:, :, :self.d_s, :]
y_s = xy_s[:, :, self.d_s:, :]
x_s = rearrange(x_s, "B T S C -> B (T S) C")
y_s = rearrange(y_s, "B T L C -> B (T L) C")
x = x + self.drop_path(gate_msa * x_s)
y = y + self.drop_path(gate_msa_y * y_s)
x_t = t2i_modulate(self.norm2(x), shift_msa_temp, scale_msa_temp)
y_t = t2i_modulate(self.norm2_y(y), shift_msa_y_temp, scale_msa_y_temp)
x_t = rearrange(x_t, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
y_t = rearrange(y_t, "B (T L) C -> B T L C", T=self.d_t, L=L)
xy_t = torch.cat([x_t, y_t], dim=2)
# temporal branch
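        # each spatial (or caption) position attends across the T frames; the temporal
        # positional embedding tpe is only injected in the first block.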
xy_t = rearrange(xy_t, "B T N C -> (B N) T C")
if tpe is not None:
xy_t = xy_t + tpe
xy_t = self.attn_temp(xy_t)
xy_t = rearrange(xy_t, "(B N) T C -> B T N C", B=B, N=self.d_s+L)
x_t = xy_t[:, :, :self.d_s, :]
y_t = xy_t[:, :, self.d_s:, :]
x_t = rearrange(x_t, "B T S C -> B (T S) C")
y_t = rearrange(y_t, "B T L C -> B (T L) C")
x = x + self.drop_path(gate_msa_temp * x_t)
y = y + self.drop_path(gate_msa_y_temp * y_t)
x = rearrange(x, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
y = rearrange(y, "B (T L) C -> (B T) L C", T=self.d_t, L=L)
mask = rearrange(mask, "B T L -> (B T) L")
# cross attn
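        # per-frame masked cross-attention: video tokens query the caption tokens.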
x = x + self.cross_attn(x, y, mask)
x = rearrange(x, "(B T) S C -> B (T S) C", B=B, T=self.d_t)
y = rearrange(y, "(B T) L C -> B (T L) C", B=B, T=self.d_t)
# mlp
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm3(x), shift_mlp, scale_mlp)))
y = y + self.drop_path(gate_mlp_y * self.mlp_y(t2i_modulate(self.norm3_y(y), shift_mlp_y, scale_mlp_y)))
y = rearrange(y, "B (T L) C -> B T L C", T=self.d_t, L=L)
return x, y
@MODELS.register_module()
class STDiT_MMDiT(nn.Module):
def __init__(
self,
input_size=(1, 32, 32),
in_channels=4,
patch_size=(1, 2, 2),
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
pred_sigma=True,
drop_path=0.0,
no_temporal_pos_emb=False,
caption_channels=4096,
model_max_length=120,
dtype=torch.float32,
space_scale=1.0,
time_scale=1.0,
freeze=None,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
qk_norm=False,
):
super().__init__()
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.input_size = input_size
num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
self.num_patches = num_patches
self.num_temporal = input_size[0] // patch_size[0]
self.num_spatial = num_patches // self.num_temporal
self.num_heads = num_heads
self.dtype = dtype
self.no_temporal_pos_emb = no_temporal_pos_emb
self.depth = depth
self.mlp_ratio = mlp_ratio
self.enable_flashattn = enable_flashattn
self.enable_layernorm_kernel = enable_layernorm_kernel
self.space_scale = space_scale
self.time_scale = time_scale
self.register_buffer("pos_embed", self.get_spatial_pos_embed())
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
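        # adaLN-single conditioning heads on the timestep embedding: 6 * hidden_size
        # for the spatial/MLP modulation of x and y, 3 * hidden_size for their
        # temporal-attention modulation.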
self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
self.t_block_y = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
self.t_block_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True))
self.t_block_y_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True))
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
uncond_prob=class_dropout_prob,
act_layer=approx_gelu,
token_num=model_max_length,
)
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
self.blocks = nn.ModuleList(
[
STDiTBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=self.mlp_ratio,
drop_path=drop_path[i],
enable_flashattn=self.enable_flashattn,
enable_layernorm_kernel=self.enable_layernorm_kernel,
enable_sequence_parallelism=enable_sequence_parallelism,
d_t=self.num_temporal,
d_s=self.num_spatial,
qk_norm=qk_norm,
)
for i in range(self.depth)
]
)
self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
# init model
self.initialize_weights()
self.initialize_temporal()
if freeze is not None:
assert freeze in ["not_temporal", "text", "not_attn"]
if freeze == "not_temporal":
self.freeze_not_temporal()
elif freeze == "text":
self.freeze_text()
elif freeze == "not_attn":
self.freeze_not_attn()
# sequence parallel related configs
self.enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.sp_rank = dist.get_rank(get_sequence_parallel_group())
else:
self.sp_rank = None
def forward(self, x, timestep, y, mask=None):
"""
        Forward pass of STDiT_MMDiT.
Args:
x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
timestep (torch.Tensor): diffusion time steps; of shape [B]
y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
Returns:
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # [B, (THW), C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
if self.enable_sequence_parallelism:
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
        t0 = self.t_block(t)  # [B, 6 * C]
        t_y = self.t_block_y(t)  # [B, 6 * C]
        t0_temp = self.t_block_temp(t)  # [B, 3 * C]
        t_y_temp = self.t_block_y_temp(t)  # [B, 3 * C]
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
y = y.repeat(1, self.num_temporal, 1, 1) # B T L C
mask = mask.unsqueeze(1).repeat(1, self.num_temporal, 1) # B T L
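            # every frame now carries its own copy of the caption tokens and their mask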
# mask = mask.squeeze(1).squeeze(1)
# y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
# y_lens = mask.sum(dim=1).tolist()
        else:
            # unmasked path: y_lens is computed but not consumed by the MM-DiT blocks below
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
# blocks
for i, block in enumerate(self.blocks):
if i == 0:
if self.enable_sequence_parallelism:
tpe = torch.chunk(
self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
)[self.sp_rank].contiguous()
else:
tpe = self.pos_embed_temporal
else:
tpe = None
            x, y = auto_grad_checkpoint(block, x, y, t0, t_y, t0_temp, t_y_temp, mask, tpe)
if self.enable_sequence_parallelism:
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
# x.shape: [B, N, C]
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def unpatchify(self, x):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
def unpatchify_old(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, grid_size=None):
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def freeze_not_temporal(self):
for n, p in self.named_parameters():
if "attn_temp" not in n:
p.requires_grad = False
def freeze_text(self):
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False
def freeze_not_attn(self):
for n, p in self.named_parameters():
if "attn" not in n:
p.requires_grad = False
if "cross_attn" in n or "attn_temp" in n:
p.requires_grad = False
def initialize_temporal(self):
for block in self.blocks:
nn.init.constant_(block.attn_temp.proj.weight, 0)
nn.init.constant_(block.attn_temp.proj.bias, 0)
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
nn.init.normal_(self.t_block_y[1].weight, std=0.02)
nn.init.normal_(self.t_block_temp[1].weight, std=0.02)
nn.init.normal_(self.t_block_y_temp[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
        # Zero-out cross-attention output projections so cross-attn contributes nothing at init:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@MODELS.register_module("STDiT_MMDiT_XL/2")
def STDiT_MMDiT_XL_2(from_pretrained=None, **kwargs):
model = STDiT_MMDiT(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
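# A minimal usage sketch (illustrative only, kept as a comment so nothing runs on import;
# it assumes the opensora package is installed and that the sizes below, which follow the
# forward() docstring and the defaults above, match your configuration):
#
#     model = STDiT_MMDiT_XL_2(input_size=(16, 32, 32), in_channels=4)
#     x = torch.randn(2, 4, 16, 32, 32)            # [B, C, T, H, W] video latents
#     timestep = torch.randint(0, 1000, (2,))      # [B] diffusion steps
#     y = torch.randn(2, 1, 120, 4096)             # [B, 1, N_token, caption_channels]
#     mask = torch.ones(2, 120, dtype=torch.long)  # [B, N_token] prompt-token mask
#     out = model(x, timestep, y, mask)            # [B, 2 * C, T, H, W] since pred_sigma=True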
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.models.layers.blocks import (
Attention,
Attention_QKNorm_RoPE,
MaskedSelfAttention,
CaptionEmbedder,
MultiHeadCrossAttention,
MaskedMultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
# import ipdb
from opensora.models.layers.timm_uvit import trunc_normal_
class STDiTBlock(nn.Module):
def __init__(
self,
hidden_size,
num_heads,
d_s=None,
d_t=None,
mlp_ratio=4.0,
drop_path=0.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
qk_norm=False,
):
super().__init__()
self.hidden_size = hidden_size
self.enable_flashattn = enable_flashattn
self._enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.attn_cls = SeqParallelAttention
self.mha_cls = SeqParallelMultiHeadCrossAttention
        else:
            # non-sequence-parallel path: masked joint self-attention over the
            # concatenated video/text tokens, RoPE attention with optional QK-norm
            # for the temporal branch, and masked cross-attention.
            # NOTE: self_masked_attn is only defined on this branch.
            self.self_masked_attn = MaskedSelfAttention
            self.attn_cls = Attention_QKNorm_RoPE
            self.mha_cls = MaskedMultiHeadCrossAttention
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.norm1_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.norm2_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.attn = self.self_masked_attn(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=enable_flashattn,
qk_norm=qk_norm,
)
self.cross_attn = self.mha_cls(hidden_size, num_heads)
self.norm3 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.norm3_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
)
self.mlp_y = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
self.scale_shift_table_y = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
self.scale_shift_table_temp = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)
self.scale_shift_table_y_temp = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)
# temporal attention
self.d_s = d_s
self.d_t = d_t
if self._enable_sequence_parallelism:
sp_size = dist.get_world_size(get_sequence_parallel_group())
# make sure d_t is divisible by sp_size
assert d_t % sp_size == 0
self.d_t = d_t // sp_size
self.attn_temp = self.attn_cls(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=self.enable_flashattn,
qk_norm=qk_norm,
)
    def forward(self, x, y, t, t_y, t_temp, t_y_temp, mask=None, tpe=None):
B, N, C = x.shape
L = y.shape[2] # y: B T L C, mask: B T L
x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
x_mask = torch.ones(x.shape[:3], device=x.device, dtype=x.dtype) # B T S
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + t.reshape(B, 6, -1)
).chunk(6, dim=1)
shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = (
self.scale_shift_table_y[None] + t_y.reshape(B, 6, -1)
).chunk(6, dim=1)
shift_msa_temp, scale_msa_temp, gate_msa_temp = (
            self.scale_shift_table_temp[None] + t_temp.reshape(B, 3, -1)
).chunk(3, dim=1)
shift_msa_y_temp, scale_msa_y_temp, gate_msa_y_temp = (
            self.scale_shift_table_y_temp[None] + t_y_temp.reshape(B, 3, -1)
).chunk(3, dim=1)
x = rearrange(x, "B T S C -> B (T S) C")
y = rearrange(y, "B T L C -> B (T L) C")
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
y_m = t2i_modulate(self.norm1_y(y), shift_msa_y, scale_msa_y)
x_m = rearrange(x_m, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
y_m = rearrange(y_m, "B (T L) C -> B T L C", T=self.d_t, L=L)
xy_m = torch.cat([x_m, y_m], dim=2)
xy_mask = torch.cat([x_mask, mask], dim=2)
xy_mask = rearrange(xy_mask, "B T N -> (B T) N")
# spatial branch
xy_s = rearrange(xy_m, "B T N C -> (B T) N C")
xy_s = self.attn(xy_s, xy_mask)
xy_s = rearrange(xy_s, "(B T) N C -> B T N C", B=B, T=self.d_t)
x_s = xy_s[:, :, :self.d_s, :]
y_s = xy_s[:, :, self.d_s:, :]
x_s = rearrange(x_s, "B T S C -> B (T S) C")
y_s = rearrange(y_s, "B T L C -> B (T L) C")
x = x + self.drop_path(gate_msa * x_s)
y = y + self.drop_path(gate_msa_y * y_s)
x_t = t2i_modulate(self.norm2(x), shift_msa_temp, scale_msa_temp)
y_t = t2i_modulate(self.norm2_y(y), shift_msa_y_temp, scale_msa_y_temp)
x_t = rearrange(x_t, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
y_t = rearrange(y_t, "B (T L) C -> B T L C", T=self.d_t, L=L)
xy_t = torch.cat([x_t, y_t], dim=2)
# temporal branch
xy_t = rearrange(xy_t, "B T N C -> (B N) T C")
if tpe is not None:
xy_t = xy_t + tpe
xy_t = self.attn_temp(xy_t)
xy_t = rearrange(xy_t, "(B N) T C -> B T N C", B=B, N=self.d_s+L)
x_t = xy_t[:, :, :self.d_s, :]
y_t = xy_t[:, :, self.d_s:, :]
x_t = rearrange(x_t, "B T S C -> B (T S) C")
y_t = rearrange(y_t, "B T L C -> B (T L) C")
x = x + self.drop_path(gate_msa_temp * x_t)
y = y + self.drop_path(gate_msa_y_temp * y_t)
x = rearrange(x, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
y = rearrange(y, "B (T L) C -> (B T) L C", T=self.d_t, L=L)
mask = rearrange(mask, "B T L -> (B T) L")
# cross attn
x = x + self.cross_attn(x, y, mask)
x = rearrange(x, "(B T) S C -> B (T S) C", B=B, T=self.d_t)
y = rearrange(y, "(B T) L C -> B (T L) C", B=B, T=self.d_t)
# mlp
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm3(x), shift_mlp, scale_mlp)))
y = y + self.drop_path(gate_mlp_y * self.mlp_y(t2i_modulate(self.norm3_y(y), shift_mlp_y, scale_mlp_y)))
y = rearrange(y, "B (T L) C -> B T L C", T=self.d_t, L=L)
return x, y
@MODELS.register_module()
class STDiT_MMDiTQK(nn.Module):
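    # Same block and embedding architecture as STDiT_MMDiT above; the main differences
    # are the qk_norm=True default passed to the attention layers and the reduced set
    # of freeze options ("not_temporal", "text").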
def __init__(
self,
input_size=(1, 32, 32),
in_channels=4,
patch_size=(1, 2, 2),
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
pred_sigma=True,
drop_path=0.0,
no_temporal_pos_emb=False,
caption_channels=4096,
model_max_length=120,
dtype=torch.float32,
space_scale=1.0,
time_scale=1.0,
freeze=None,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
qk_norm=True,
):
super().__init__()
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.input_size = input_size
num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
self.num_patches = num_patches
self.num_temporal = input_size[0] // patch_size[0]
self.num_spatial = num_patches // self.num_temporal
self.num_heads = num_heads
self.dtype = dtype
self.no_temporal_pos_emb = no_temporal_pos_emb
self.depth = depth
self.mlp_ratio = mlp_ratio
self.enable_flashattn = enable_flashattn
self.enable_layernorm_kernel = enable_layernorm_kernel
self.space_scale = space_scale
self.time_scale = time_scale
self.register_buffer("pos_embed", self.get_spatial_pos_embed())
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
self.t_block_y = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
self.t_block_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True))
self.t_block_y_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True))
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
uncond_prob=class_dropout_prob,
act_layer=approx_gelu,
token_num=model_max_length,
)
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
self.blocks = nn.ModuleList(
[
STDiTBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=self.mlp_ratio,
drop_path=drop_path[i],
enable_flashattn=self.enable_flashattn,
enable_layernorm_kernel=self.enable_layernorm_kernel,
enable_sequence_parallelism=enable_sequence_parallelism,
d_t=self.num_temporal,
d_s=self.num_spatial,
qk_norm=qk_norm,
)
for i in range(self.depth)
]
)
self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
# init model
self.initialize_weights()
self.initialize_temporal()
if freeze is not None:
assert freeze in ["not_temporal", "text"]
if freeze == "not_temporal":
self.freeze_not_temporal()
elif freeze == "text":
self.freeze_text()
# sequence parallel related configs
self.enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.sp_rank = dist.get_rank(get_sequence_parallel_group())
else:
self.sp_rank = None
def forward(self, x, timestep, y, mask=None):
"""
        Forward pass of STDiT_MMDiTQK.
Args:
x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
timestep (torch.Tensor): diffusion time steps; of shape [B]
y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
Returns:
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # [B, (THW), C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
if self.enable_sequence_parallelism:
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
        t0 = self.t_block(t)  # [B, 6 * C]
        t_y = self.t_block_y(t)  # [B, 6 * C]
        t0_temp = self.t_block_temp(t)  # [B, 3 * C]
        t_y_temp = self.t_block_y_temp(t)  # [B, 3 * C]
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
y = y.repeat(1, self.num_temporal, 1, 1) # B T L C
mask = mask.unsqueeze(1).repeat(1, self.num_temporal, 1) # B T L
# mask = mask.squeeze(1).squeeze(1)
# y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
# y_lens = mask.sum(dim=1).tolist()
        else:
            # unmasked path: y_lens is computed but not consumed by the MM-DiT blocks below
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
# blocks
for i, block in enumerate(self.blocks):
if i == 0:
if self.enable_sequence_parallelism:
tpe = torch.chunk(
self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
)[self.sp_rank].contiguous()
else:
tpe = self.pos_embed_temporal
else:
tpe = None
            x, y = auto_grad_checkpoint(block, x, y, t0, t_y, t0_temp, t_y_temp, mask, tpe)
if self.enable_sequence_parallelism:
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
# x.shape: [B, N, C]
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def unpatchify(self, x):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
def unpatchify_old(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, grid_size=None):
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def freeze_not_temporal(self):
for n, p in self.named_parameters():
if "attn_temp" not in n:
p.requires_grad = False
def freeze_text(self):
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False
def initialize_temporal(self):
for block in self.blocks:
nn.init.constant_(block.attn_temp.proj.weight, 0)
nn.init.constant_(block.attn_temp.proj.bias, 0)
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
nn.init.normal_(self.t_block_y[1].weight, std=0.02)
nn.init.normal_(self.t_block_temp[1].weight, std=0.02)
nn.init.normal_(self.t_block_y_temp[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
        # Zero-out cross-attention output projections so cross-attn contributes nothing at init:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@MODELS.register_module("STDiT_MMDiTQK_XL/2")
def STDiT_MMDiTQK_XL_2(from_pretrained=None, **kwargs):
model = STDiT_MMDiTQK(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model