ModelZoo / STAR · Commits · 1f5da520

Commit 1f5da520, authored Dec 05, 2025 by yangzhong

    git init

Pipeline #3144 failed in 0 seconds
Changes: 326 · Pipelines: 1
Showing 20 changed files with 4507 additions and 0 deletions (+4507 -0)
utils_data/opensora/models/layers/blocks_vis.py  (+1050 -0)
utils_data/opensora/models/layers/timm_uvit.py  (+113 -0)
utils_data/opensora/models/pixart/__init__.py  (+1 -0)
utils_data/opensora/models/pixart/__pycache__/__init__.cpython-39.pyc  (+0 -0)
utils_data/opensora/models/pixart/__pycache__/pixart.cpython-39.pyc  (+0 -0)
utils_data/opensora/models/pixart/pixart.py  (+389 -0)
utils_data/opensora/models/stdit/__init__.py  (+20 -0)
utils_data/opensora/models/stdit/__pycache__/__init__.cpython-39.pyc  (+0 -0)
utils_data/opensora/models/stdit/__pycache__/stdit.cpython-39.pyc  (+0 -0)
utils_data/opensora/models/stdit/__pycache__/stdit_freq.cpython-39.pyc  (+0 -0)
utils_data/opensora/models/stdit/__pycache__/stdit_mmdit.cpython-39.pyc  (+0 -0)
utils_data/opensora/models/stdit/__pycache__/stdit_qknorm_rope.cpython-39.pyc  (+0 -0)
utils_data/opensora/models/stdit/stdit.py  (+391 -0)
utils_data/opensora/models/stdit/stdit_controlnet.py  (+286 -0)
utils_data/opensora/models/stdit/stdit_controlnet_freq.py  (+321 -0)
utils_data/opensora/models/stdit/stdit_controlnet_mvdit.py  (+293 -0)
utils_data/opensora/models/stdit/stdit_controlnet_qknorm.py  (+287 -0)
utils_data/opensora/models/stdit/stdit_freq.py  (+419 -0)
utils_data/opensora/models/stdit/stdit_mmdit.py  (+473 -0)
utils_data/opensora/models/stdit/stdit_mmdit_qk.py  (+464 -0)
utils_data/opensora/models/layers/blocks_vis.py (new file, mode 100644)
# # This source code is licensed under the license found in the
# # LICENSE file in the root directory of this source tree.
# # --------------------------------------------------------
# # References:
# # PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# # Latte: https://github.com/Vchitect/Latte
# # DiT: https://github.com/facebookresearch/DiT/tree/main
# # GLIDE: https://github.com/openai/glide-text2im
# # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# # --------------------------------------------------------
# import math
# from typing import KeysView
# import numpy as np
# import torch
# import torch.distributed as dist
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.utils.checkpoint
# import xformers.ops
# from einops import rearrange
# from timm.models.vision_transformer import Mlp
# from opensora.acceleration.communications import all_to_all, split_forward_gather_backward
# from opensora.acceleration.parallel_states import get_sequence_parallel_group
# import ipdb
# import cv2
# import os
# approx_gelu = lambda: nn.GELU(approximate="tanh")
# class LlamaRMSNorm(nn.Module):
# def __init__(self, hidden_size, eps=1e-6):
# """
# LlamaRMSNorm is equivalent to T5LayerNorm
# """
# super().__init__()
# self.weight = nn.Parameter(torch.ones(hidden_size))
# self.variance_epsilon = eps
# def forward(self, hidden_states):
# #ipdb.set_trace()
# input_dtype = hidden_states.dtype
# hidden_states = hidden_states.to(torch.float32)
# variance = hidden_states.pow(2).mean(-1, keepdim=True)
# hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# #ipdb.set_trace()
# return self.weight * hidden_states.to(input_dtype)
# def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel: bool):
# if use_kernel:
# try:
# from apex.normalization import FusedLayerNorm
# return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
# except ImportError:
# raise RuntimeError("FusedLayerNorm not available. Please install apex.")
# else:
# return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)
# def modulate(norm_func, x, shift, scale):
# # Suppose x is (B, N, D), shift is (B, D), scale is (B, D)
# dtype = x.dtype
# x = norm_func(x.to(torch.float32)).to(dtype)
# x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
# x = x.to(dtype)
# return x
# def t2i_modulate(x, shift, scale):
# return x * (1 + scale) + shift
# def get_attn_mask(query, key, idx):
# scale = 1.0 / query.shape[-1] ** 0.5
# query = query * scale
# query = query.transpose(0, 1)
# key = key.transpose(0, 1)
# attn = query @ key.transpose(-2, -1)
# attn = attn.softmax(-1) # H S L
# #attn = F.dropout(attn, p=0.0)
# # attn[attn>0.5]=1
# # attn[attn<=0.5]=0
# # attn[attn==1]=255
# attn = attn * 255
# H, S, L = attn.shape
# hight = 64
# width = 64
# for h in range(H):
# for l in range(L):
# map = attn[h, :, l]
# map = rearrange(map, '(H W) -> H W', H=hight, W=width)
# map = F.interpolate(map.unsqueeze(0).unsqueeze(0), size=[256,256],mode='nearest')
# np_array = map.squeeze(0).squeeze(0).detach().cpu().numpy()
# image = cv2.imwrite(os.path.join("/mnt/bn/yh-volume0/code/debug/code/OpenSora/outputs/vis", 'map'+str(idx)+'_head'+str(h)+'_word'+str(l)+'.jpg'), np_array)
# print("Image saved successfully!")
# # ===============================================
# # General-purpose Layers
# # ===============================================
# class PatchEmbed3D(nn.Module):
# """Video to Patch Embedding.
# Args:
# patch_size (int): Patch token size. Default: (2,4,4).
# in_chans (int): Number of input video channels. Default: 3.
# embed_dim (int): Number of linear projection output channels. Default: 96.
# norm_layer (nn.Module, optional): Normalization layer. Default: None
# """
# def __init__(
# self,
# patch_size=(2, 4, 4),
# in_chans=3,
# embed_dim=96,
# norm_layer=None,
# flatten=True,
# ):
# super().__init__()
# self.patch_size = patch_size
# self.flatten = flatten
# self.in_chans = in_chans
# self.embed_dim = embed_dim
# self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
# if norm_layer is not None:
# self.norm = norm_layer(embed_dim)
# else:
# self.norm = None
# def forward(self, x):
# """Forward function."""
# # padding
# _, _, D, H, W = x.size()
# if W % self.patch_size[2] != 0:
# x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
# if H % self.patch_size[1] != 0:
# x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
# if D % self.patch_size[0] != 0:
# x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
# x = self.proj(x) # (B C T H W)
# if self.norm is not None:
# D, Wh, Ww = x.size(2), x.size(3), x.size(4)
# x = x.flatten(2).transpose(1, 2)
# x = self.norm(x)
# x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
# if self.flatten:
# x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
# return x
# class Attention(nn.Module):
# def __init__(
# self,
# dim: int,
# num_heads: int = 8,
# qkv_bias: bool = False,
# qk_norm: bool = False,
# attn_drop: float = 0.0,
# proj_drop: float = 0.0,
# norm_layer: nn.Module = nn.LayerNorm,
# enable_flashattn: bool = False,
# ) -> None:
# super().__init__()
# assert dim % num_heads == 0, "dim should be divisible by num_heads"
# self.dim = dim
# self.num_heads = num_heads
# self.head_dim = dim // num_heads
# self.scale = self.head_dim**-0.5
# self.enable_flashattn = enable_flashattn
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(dim, dim)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x: torch.Tensor) -> torch.Tensor:
# B, N, C = x.shape
# qkv = self.qkv(x)
# qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
# if self.enable_flashattn: # here
# qkv_permute_shape = (2, 0, 1, 3, 4)
# else:
# qkv_permute_shape = (2, 0, 3, 1, 4)
# qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
# q, k, v = qkv.unbind(0)
# q, k = self.q_norm(q), self.k_norm(k)
# if self.enable_flashattn:
# from flash_attn import flash_attn_func
# x = flash_attn_func(
# q,
# k,
# v,
# dropout_p=self.attn_drop.p if self.training else 0.0,
# softmax_scale=self.scale,
# )
# else:
# dtype = q.dtype
# q = q * self.scale
# attn = q @ k.transpose(-2, -1) # translate attn to float32
# attn = attn.to(torch.float32)
# attn = attn.softmax(dim=-1)
# attn = attn.to(dtype) # cast back attn to original dtype
# attn = self.attn_drop(attn)
# x = attn @ v
# x_output_shape = (B, N, C)
# if not self.enable_flashattn:
# x = x.transpose(1, 2)
# x = x.reshape(x_output_shape)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class Attention_QKNorm_RoPE(nn.Module):
# def __init__(
# self,
# dim: int,
# num_heads: int = 8,
# qkv_bias: bool = False,
# qk_norm: bool = False,
# attn_drop: float = 0.0,
# proj_drop: float = 0.0,
# norm_layer: nn.Module = LlamaRMSNorm,
# enable_flashattn: bool = False,
# rope=None,
# ) -> None:
# super().__init__()
# assert dim % num_heads == 0, "dim should be divisible by num_heads"
# self.dim = dim
# self.num_heads = num_heads
# self.head_dim = dim // num_heads
# self.scale = self.head_dim**-0.5
# self.enable_flashattn = enable_flashattn
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(dim, dim)
# self.proj_drop = nn.Dropout(proj_drop)
# self.rotary_emb = rope
# def forward(self, x: torch.Tensor) -> torch.Tensor:
# B, N, C = x.shape
# qkv = self.qkv(x)
# qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
# if self.enable_flashattn:
# qkv_permute_shape = (2, 0, 1, 3, 4)
# else:
# qkv_permute_shape = (2, 0, 3, 1, 4)
# qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
# q, k, v = qkv.unbind(0)
# #ipdb.set_trace()
# if self.rotary_emb is not None:
# q = self.rotary_emb(q)
# k = self.rotary_emb(k)
# #ipdb.set_trace()
# q, k = self.q_norm(q), self.k_norm(k)
# #ipdb.set_trace()
# if self.enable_flashattn:
# from flash_attn import flash_attn_func
# x = flash_attn_func(
# q,
# k,
# v,
# dropout_p=self.attn_drop.p if self.training else 0.0,
# softmax_scale=self.scale,
# )
# else:
# dtype = q.dtype
# q = q * self.scale
# attn = q @ k.transpose(-2, -1) # translate attn to float32
# attn = attn.to(torch.float32)
# attn = attn.softmax(dim=-1)
# attn = attn.to(dtype) # cast back attn to original dtype
# attn = self.attn_drop(attn)
# x = attn @ v
# x_output_shape = (B, N, C)
# if not self.enable_flashattn:
# x = x.transpose(1, 2)
# x = x.reshape(x_output_shape)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class MaskedSelfAttention(nn.Module):
# def __init__(
# self,
# dim: int,
# num_heads: int = 8,
# qkv_bias: bool = False,
# qk_norm: bool = False,
# attn_drop: float = 0.0,
# proj_drop: float = 0.0,
# norm_layer: nn.Module = LlamaRMSNorm,
# enable_flashattn: bool = False,
# rope=None,
# ) -> None:
# super().__init__()
# assert dim % num_heads == 0, "dim should be divisible by num_heads"
# self.dim = dim
# self.num_heads = num_heads
# self.head_dim = dim // num_heads
# self.scale = self.head_dim**-0.5
# self.enable_flashattn = enable_flashattn
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(dim, dim)
# self.proj_drop = nn.Dropout(proj_drop)
# self.rotary_emb = rope
# def forward(self, x, mask):
# B, N, C = x.shape
# qkv = self.qkv(x)
# qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
# qkv_permute_shape = (2, 0, 3, 1, 4)
# qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
# q, k, v = qkv.unbind(0) # B H N C
# #ipdb.set_trace()
# if self.rotary_emb is not None:
# q = self.rotary_emb(q)
# k = self.rotary_emb(k)
# #ipdb.set_trace()
# q, k = self.q_norm(q), self.k_norm(k)
# #ipdb.set_trace()
# mask = mask.unsqueeze(1).unsqueeze(1).repeat(1, self.num_heads, 1, 1).to(torch.float32) # B H 1 N
# dtype = q.dtype
# q = q * self.scale
# attn = q @ k.transpose(-2, -1) # translate attn to float32
# attn = attn.to(torch.float32)
# attn = attn.masked_fill(mask == 0, -1e9)
# attn = attn.softmax(dim=-1)
# attn = attn.to(dtype) # cast back attn to original dtype
# attn = self.attn_drop(attn)
# x = attn @ v
# x_output_shape = (B, N, C)
# x = x.transpose(1, 2)
# x = x.reshape(x_output_shape)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class SeqParallelAttention(Attention):
# def __init__(
# self,
# dim: int,
# num_heads: int = 8,
# qkv_bias: bool = False,
# qk_norm: bool = False,
# attn_drop: float = 0.0,
# proj_drop: float = 0.0,
# norm_layer: nn.Module = nn.LayerNorm,
# enable_flashattn: bool = False,
# ) -> None:
# super().__init__(
# dim=dim,
# num_heads=num_heads,
# qkv_bias=qkv_bias,
# qk_norm=qk_norm,
# attn_drop=attn_drop,
# proj_drop=proj_drop,
# norm_layer=norm_layer,
# enable_flashattn=enable_flashattn,
# )
# def forward(self, x: torch.Tensor) -> torch.Tensor:
# B, N, C = x.shape # for sequence parallel here, the N is a local sequence length
# qkv = self.qkv(x)
# qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
# qkv = qkv.view(qkv_shape)
# sp_group = get_sequence_parallel_group()
# # apply all_to_all to gather sequence and split attention heads
# # [B, SUB_N, 3, NUM_HEAD, HEAD_DIM] -> [B, N, 3, NUM_HEAD_PER_DEVICE, HEAD_DIM]
# qkv = all_to_all(qkv, sp_group, scatter_dim=3, gather_dim=1)
# if self.enable_flashattn:
# qkv_permute_shape = (2, 0, 1, 3, 4) # [3, B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM]
# else:
# qkv_permute_shape = (2, 0, 3, 1, 4) # [3, B, NUM_HEAD_PER_DEVICE, N, HEAD_DIM]
# qkv = qkv.permute(qkv_permute_shape)
# q, k, v = qkv.unbind(0)
# q, k = self.q_norm(q), self.k_norm(k)
# if self.enable_flashattn:
# from flash_attn import flash_attn_func
# x = flash_attn_func(
# q,
# k,
# v,
# dropout_p=self.attn_drop.p if self.training else 0.0,
# softmax_scale=self.scale,
# )
# else:
# dtype = q.dtype
# q = q * self.scale
# attn = q @ k.transpose(-2, -1) # translate attn to float32
# attn = attn.to(torch.float32)
# attn = attn.softmax(dim=-1)
# attn = attn.to(dtype) # cast back attn to original dtype
# attn = self.attn_drop(attn)
# x = attn @ v
# if not self.enable_flashattn:
# x = x.transpose(1, 2)
# # apply all to all to gather back attention heads and split sequence
# # [B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM] -> [B, SUB_N, NUM_HEAD, HEAD_DIM]
# x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)
# # reshape outputs back to [B, N, C]
# x_output_shape = (B, N, C)
# x = x.reshape(x_output_shape)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class MultiHeadCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(MultiHeadCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None, i=0, t=0):
# # query/value: img tokens; key: condition; mask: if padding tokens
# B, N, C = x.shape
# q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
# kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
# k, v = kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
# # vis
# print(t[0].item())
# if i == 27 and t[0].item() >= 480.0 and t[0].item() <= 500.0:
# q1 = q[:, :N, :, :].squeeze(0) # S H C
# q2 = q[:, N:, :, :].squeeze(0) # S H C
# k1 = k[:, :mask[0], :, :].squeeze(0) # L H C
# k2 = k[:, mask[0]:, :, :].squeeze(0) # L H C
# get_attn_mask(q1, k1, 1)
# get_attn_mask(q2, k2, 2)
# # vis
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(B, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class MaskedMultiHeadCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(MaskedMultiHeadCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # query/value: img tokens; key: condition; mask: if padding tokens
# B, S, C = x.shape
# L = cond.shape[1]
# q = self.q_linear(x).view(B, S, self.num_heads, self.head_dim)
# kv = self.kv_linear(cond).view(B, L, 2, self.num_heads, self.head_dim)
# k, v = kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# attn_bias = mask.unsqueeze(1).unsqueeze(1).repeat(1, self.num_heads, S, 1).to(q.dtype) # B H S L
# exp = -1e9
# attn_bias[attn_bias==0] = exp
# attn_bias[attn_bias==1] = 0
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(B, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class LongShortMultiHeadCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(LongShortMultiHeadCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # query/value: img tokens; key: condition; mask: if padding tokens
# B, N, C = x.shape
# M = cond.shape[1]
# q = self.q_linear(x).view(B, N, self.num_heads, self.head_dim)
# kv = self.kv_linear(cond).view(B, M, 2, self.num_heads, self.head_dim)
# k, v = kv.unbind(2)
# attn_bias = None
# if mask is not None:
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# x = x.view(B, N, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class MultiHeadV2TCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(MultiHeadV2TCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # query/value: condition; key: img tokens; mask: if padding tokens
# B, N, C = cond.shape
# q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
# kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
# k, v = kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [N] * B)
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(B, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class MultiHeadT2VCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(MultiHeadT2VCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # query/value: img tokens; key: condition; mask: if padding tokens
# #ipdb.set_trace()
# B, T, N, C = x.shape
# x = rearrange(x, 'B T N C -> (B T) N C')
# q = self.q_linear(x)
# q = rearrange(q, '(B T) N C -> B T N C', T=T)
# q = q.view(1, -1, self.num_heads, self.head_dim) # 1(B T N) H C
# kv = self.kv_linear(cond)
# kv = kv.view(1, -1, 2, self.num_heads, self.head_dim) # 1 N 2 H C
# k, v = kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# #mask = [m for m in mask for _ in range(T)]
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * (B*T), mask)
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(B, T, N, C)
# x = rearrange(x, 'B T N C -> (B T) N C')
# x = self.proj(x)
# x = self.proj_drop(x)
# x = rearrange(x, '(B T) N C -> B T N C', T=T)
# return x
# class FormerMultiHeadV2TCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(FormerMultiHeadV2TCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # x: text tokens; cond: img tokens; mask: if padding tokens
# #ipdb.set_trace()
# _, N, C = x.shape # 1 N C
# B, T, _, _ = cond.shape
# cond = rearrange(cond, 'B T N C -> (B T) N C')
# q = self.q_linear(x)
# q = q.view(1, -1, self.num_heads, self.head_dim) # 1 N H C
# kv = self.kv_linear(cond)
# kv = rearrange(kv, '(B T) N C -> B T N C', B=B)
# M = kv.shape[2] # M = H * W
# former_frame_index = torch.arange(T) - 1
# former_frame_index[0] = 0
# #ipdb.set_trace()
# former_kv = kv[:, former_frame_index]
# former_kv = former_kv.view(1, -1, 2, self.num_heads, self.head_dim) # 1(B T N) 2 H C
# former_k, former_v = former_kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# #mask = [m for m in mask for _ in range(T)]
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [M] * (B*T))
# x = xformers.ops.memory_efficient_attention(q, former_k, former_v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(1, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class LatterMultiHeadV2TCrossAttention(nn.Module):
# def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
# super(LatterMultiHeadV2TCrossAttention, self).__init__()
# assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# self.d_model = d_model
# self.num_heads = num_heads
# self.head_dim = d_model // num_heads
# self.q_linear = nn.Linear(d_model, d_model)
# self.kv_linear = nn.Linear(d_model, d_model * 2)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(d_model, d_model)
# self.proj_drop = nn.Dropout(proj_drop)
# def forward(self, x, cond, mask=None):
# # x: text tokens; cond: img tokens; mask: if padding tokens
# #ipdb.set_trace()
# _, N, C = x.shape # 1 N C
# B, T, _, _ = cond.shape
# cond = rearrange(cond, 'B T N C -> (B T) N C')
# q = self.q_linear(x)
# q = q.view(1, -1, self.num_heads, self.head_dim) # 1 N H C
# kv = self.kv_linear(cond)
# kv = rearrange(kv, '(B T) N C -> B T N C', T=T)
# M = kv.shape[2] # M = H * W
# latter_frame_index = torch.arange(T) + 1
# latter_frame_index[-1] = T - 1
# #ipdb.set_trace()
# latter_kv = kv[:, latter_frame_index]
# latter_kv = latter_kv.view(1, -1, 2, self.num_heads, self.head_dim) # 1(B T N) 2 H C
# latter_k, latter_v = latter_kv.unbind(2)
# #ipdb.set_trace()
# attn_bias = None
# if mask is not None:
# # mask = [m for m in mask for _ in range(T)]
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [M] * (B*T))
# x = xformers.ops.memory_efficient_attention(q, latter_k, latter_v, p=self.attn_drop.p, attn_bias=attn_bias)
# #ipdb.set_trace()
# x = x.view(1, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
# def __init__(
# self,
# d_model,
# num_heads,
# attn_drop=0.0,
# proj_drop=0.0,
# ):
# super().__init__(d_model=d_model, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
# def forward(self, x, cond, mask=None):
# # query/value: img tokens; key: condition; mask: if padding tokens
# sp_group = get_sequence_parallel_group()
# sp_size = dist.get_world_size(sp_group)
# B, SUB_N, C = x.shape
# N = SUB_N * sp_size
# # shape:
# # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
# q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
# kv = self.kv_linear(cond).view(B, -1, 2, self.num_heads, self.head_dim)
# k, v = kv.unbind(2)
# # apply all_to_all to gather sequence and split attention heads
# q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1)
# k = split_forward_gather_backward(k, get_sequence_parallel_group(), dim=2, grad_scale="down")
# v = split_forward_gather_backward(v, get_sequence_parallel_group(), dim=2, grad_scale="down")
# q = q.view(1, -1, self.num_heads // sp_size, self.head_dim)
# k = k.view(1, -1, self.num_heads // sp_size, self.head_dim)
# v = v.view(1, -1, self.num_heads // sp_size, self.head_dim)
# # compute attention
# attn_bias = None
# if mask is not None:
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
# x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# # apply all to all to gather back attention heads and scatter sequence
# x = x.view(B, -1, self.num_heads // sp_size, self.head_dim)
# x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)
# # apply output projection
# x = x.view(B, -1, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
# class FinalLayer(nn.Module):
# """
# The final layer of DiT.
# """
# def __init__(self, hidden_size, num_patch, out_channels):
# super().__init__()
# self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
# self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
# def forward(self, x, c):
# shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
# x = modulate(self.norm_final, x, shift, scale)
# x = self.linear(x)
# return x
# class T2IFinalLayer(nn.Module):
# """
# The final layer of PixArt.
# """
# def __init__(self, hidden_size, num_patch, out_channels):
# super().__init__()
# self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
# self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
# self.out_channels = out_channels
# def forward(self, x, t):
# shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
# x = t2i_modulate(self.norm_final(x), shift, scale)
# x = self.linear(x)
# return x
# # ===============================================
# # Embedding Layers for Timesteps and Class Labels
# # ===============================================
# class TimestepEmbedder(nn.Module):
# """
# Embeds scalar timesteps into vector representations.
# """
# def __init__(self, hidden_size, frequency_embedding_size=256):
# super().__init__()
# self.mlp = nn.Sequential(
# nn.Linear(frequency_embedding_size, hidden_size, bias=True),
# nn.SiLU(),
# nn.Linear(hidden_size, hidden_size, bias=True),
# )
# self.frequency_embedding_size = frequency_embedding_size
# @staticmethod
# def timestep_embedding(t, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
# :param t: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an (N, D) Tensor of positional embeddings.
# """
# # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
# freqs = freqs.to(device=t.device)
# args = t[:, None].float() * freqs[None]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding
# def forward(self, t, dtype):
# t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
# if t_freq.dtype != dtype:
# t_freq = t_freq.to(dtype)
# t_emb = self.mlp(t_freq)
# return t_emb
# class LabelEmbedder(nn.Module):
# """
# Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
# """
# def __init__(self, num_classes, hidden_size, dropout_prob):
# super().__init__()
# use_cfg_embedding = dropout_prob > 0
# self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
# self.num_classes = num_classes
# self.dropout_prob = dropout_prob
# def token_drop(self, labels, force_drop_ids=None):
# """
# Drops labels to enable classifier-free guidance.
# """
# if force_drop_ids is None:
# drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
# else:
# drop_ids = force_drop_ids == 1
# labels = torch.where(drop_ids, self.num_classes, labels)
# return labels
# def forward(self, labels, train, force_drop_ids=None):
# use_dropout = self.dropout_prob > 0
# if (train and use_dropout) or (force_drop_ids is not None):
# labels = self.token_drop(labels, force_drop_ids)
# return self.embedding_table(labels)
# class SizeEmbedder(TimestepEmbedder):
# """
# Embeds scalar timesteps into vector representations.
# """
# def __init__(self, hidden_size, frequency_embedding_size=256):
# super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
# self.mlp = nn.Sequential(
# nn.Linear(frequency_embedding_size, hidden_size, bias=True),
# nn.SiLU(),
# nn.Linear(hidden_size, hidden_size, bias=True),
# )
# self.frequency_embedding_size = frequency_embedding_size
# self.outdim = hidden_size
# def forward(self, s, bs):
# if s.ndim == 1:
# s = s[:, None]
# assert s.ndim == 2
# if s.shape[0] != bs:
# s = s.repeat(bs // s.shape[0], 1)
# assert s.shape[0] == bs
# b, dims = s.shape[0], s.shape[1]
# s = rearrange(s, "b d -> (b d)")
# s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
# s_emb = self.mlp(s_freq)
# s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
# return s_emb
# @property
# def dtype(self):
# return next(self.parameters()).dtype
# class CaptionEmbedder(nn.Module):
# """
# Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
# """
# def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
# super().__init__()
# self.y_proj = Mlp(
# in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
# )
# self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
# self.uncond_prob = uncond_prob
# def token_drop(self, caption, force_drop_ids=None):
# """
# Drops labels to enable classifier-free guidance.
# """
# if force_drop_ids is None:
# drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
# else:
# drop_ids = force_drop_ids == 1
# caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
# return caption
# def forward(self, caption, train, force_drop_ids=None):
# if train:
# assert caption.shape[2:] == self.y_embedding.shape
# use_dropout = self.uncond_prob > 0
# if (train and use_dropout) or (force_drop_ids is not None):
# caption = self.token_drop(caption, force_drop_ids)
# caption = self.y_proj(caption)
# return caption
# # ===============================================
# # Sine/Cosine Positional Embedding Functions
# # ===============================================
# # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
# """
# grid_size: int of the grid height and width
# return:
# pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
# """
# if not isinstance(grid_size, tuple):
# grid_size = (grid_size, grid_size)
# grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
# grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
# if base_size is not None:
# grid_h *= base_size / grid_size[0]
# grid_w *= base_size / grid_size[1]
# grid = np.meshgrid(grid_w, grid_h) # here w goes first
# grid = np.stack(grid, axis=0)
# grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
# pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
# if cls_token and extra_tokens > 0:
# pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
# return pos_embed
# def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
# assert embed_dim % 2 == 0
# # use half of dimensions to encode grid_h
# emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
# emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
# emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
# return emb
# def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
# pos = np.arange(0, length)[..., None] / scale
# return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
# def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
# """
# embed_dim: output dimension for each position
# pos: a list of positions to be encoded: size (M,)
# out: (M, D)
# """
# assert embed_dim % 2 == 0
# omega = np.arange(embed_dim // 2, dtype=np.float64)
# omega /= embed_dim / 2.0
# omega = 1.0 / 10000**omega # (D/2,)
# pos = pos.reshape(-1) # (M,)
# out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
# emb_sin = np.sin(out) # (M, D/2)
# emb_cos = np.cos(out) # (M, D/2)
# emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
# return emb
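The commented-out sine/cosine helpers above follow the MAE recipe: for each position, half of the channels store sin(pos * omega_i) and the other half cos(pos * omega_i), with omega_i = 1 / 10000^(i / (D/2)). A minimal self-contained NumPy sketch of the 1-D case, mirroring get_1d_sincos_pos_embed_from_grid (illustrative only, shapes assumed from the code above):

import numpy as np

def sincos_1d(embed_dim, length):
    # Half of the channels are sin, half are cos, as in get_1d_sincos_pos_embed_from_grid.
    assert embed_dim % 2 == 0
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0))
    out = np.outer(np.arange(length, dtype=np.float64), omega)   # (length, embed_dim/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)    # (length, embed_dim)

emb = sincos_1d(1152, 16)   # e.g. hidden_size=1152 over 16 temporal positions
print(emb.shape)            # (16, 1152)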
utils_data/opensora/models/layers/timm_uvit.py (new file, mode 100644)
# code from timm 0.3.2
import torch
import torch.nn as nn
import math
import warnings


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
\ No newline at end of file
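A minimal usage sketch for the three helpers vendored above; the import path assumes utils_data/ is on PYTHONPATH, otherwise adjust it to wherever this file lives:

import torch
from opensora.models.layers.timm_uvit import trunc_normal_, DropPath, Mlp  # assumes utils_data/ on PYTHONPATH

w = torch.empty(3, 5)
trunc_normal_(w, std=0.02)         # in-place truncated-normal init, clamped to [-2, 2]

mlp = Mlp(in_features=8, hidden_features=32, drop=0.1)
drop = DropPath(drop_prob=0.1)     # stochastic depth; identity in eval mode

x = torch.randn(4, 16, 8)          # (batch, tokens, channels)
out = x + drop(mlp(x))             # typical residual usage inside a transformer block
print(out.shape)                   # torch.Size([4, 16, 8])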
utils_data/opensora/models/pixart/__init__.py (new file, mode 100644)
from .pixart import PixArt, PixArt_XL_2
utils_data/opensora/models/pixart/__pycache__/__init__.cpython-39.pyc (new file, mode 100644; binary file added)
utils_data/opensora/models/pixart/__pycache__/pixart.cpython-39.pyc (new file, mode 100644; binary file added)
utils_data/opensora/models/pixart/pixart.py (new file, mode 100644)
# Adapted from PixArt
#
# Copyright (C) 2023 PixArt-alpha/PixArt-alpha
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# DiT: https://github.com/facebookresearch/DiT/tree/main
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp

# from .builder import MODELS
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.layers.blocks import (
    Attention,
    CaptionEmbedder,
    MultiHeadCrossAttention,
    PatchEmbed3D,
    SeqParallelAttention,
    SeqParallelMultiHeadCrossAttention,
    SizeEmbedder,
    T2IFinalLayer,
    TimestepEmbedder,
    approx_gelu,
    get_1d_sincos_pos_embed,
    get_2d_sincos_pos_embed,
    get_layernorm,
    t2i_modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint


class PixArtBlock(nn.Module):
    """
    A PixArt block with adaptive layer norm (adaLN-single) conditioning.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        drop_path=0.0,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.enable_flashattn = enable_flashattn
        self._enable_sequence_parallelism = enable_sequence_parallelism

        if enable_sequence_parallelism:
            self.attn_cls = SeqParallelAttention
            self.mha_cls = SeqParallelMultiHeadCrossAttention
        else:
            self.attn_cls = Attention
            self.mha_cls = MultiHeadCrossAttention

        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=enable_flashattn,
        )
        self.cross_attn = self.mha_cls(hidden_size, num_heads)
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)

    def forward(self, x, y, t, mask=None):
        B, N, C = x.shape
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
        x = x + self.cross_attn(x, y, mask)
        x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
        return x


@MODELS.register_module()
class PixArt(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        input_size=(1, 32, 32),
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path: float = 0.0,
        no_temporal_pos_emb=False,
        caption_channels=4096,
        model_max_length=120,
        dtype=torch.float32,
        freeze=None,
        space_scale=1.0,
        time_scale=1.0,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
    ):
        super().__init__()
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.input_size = input_size
        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
        self.num_patches = num_patches
        self.num_temporal = input_size[0] // patch_size[0]
        self.num_spatial = num_patches // self.num_temporal
        self.base_size = int(np.sqrt(self.num_spatial))
        self.num_heads = num_heads
        self.dtype = dtype
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.enable_flashattn = enable_flashattn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.space_scale = space_scale
        self.time_scale = time_scale

        self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels,
            hidden_size=hidden_size,
            uncond_prob=class_dropout_prob,
            act_layer=approx_gelu,
            token_num=model_max_length,
        )

        self.register_buffer("pos_embed", self.get_spatial_pos_embed())
        self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())

        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                PixArtBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    drop_path=drop_path[i],
                    enable_flashattn=enable_flashattn,
                    enable_layernorm_kernel=enable_layernorm_kernel,
                )
                for i in range(depth)
            ]
        )
        self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)

        self.initialize_weights()
        if freeze is not None:
            assert freeze in ["text"]
            if freeze == "text":
                self.freeze_text()

    def forward(self, x, timestep, y, mask=None):
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # (B, N, D)
        x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
        x = x + self.pos_embed
        if not self.no_temporal_pos_emb:
            x = rearrange(x, "b t s d -> b s t d")
            x = x + self.pos_embed_temporal
            x = rearrange(x, "b s t d -> b (t s) d")
        else:
            x = rearrange(x, "b t s d -> b (t s) d")

        t = self.t_embedder(timestep, dtype=x.dtype)  # (N, D)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # blocks
        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, y_lens)

        # final process
        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def unpatchify(self, x):
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size

        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def get_spatial_pos_embed(self, grid_size=None):
        if grid_size is None:
            grid_size = self.input_size[1:]
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
            scale=self.space_scale,
            base_size=self.base_size,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[0] // self.patch_size[0],
            scale=self.time_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def freeze_text(self):
        for n, p in self.named_parameters():
            if "cross_attn" in n:
                p.requires_grad = False

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)


@MODELS.register_module()
class PixArtMS(PixArt):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3"
        self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
        self.ar_embedder = SizeEmbedder(self.hidden_size // 3)

    def forward(self, x, timestep, y, mask=None, data_info=None):
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        c_size = data_info["hw"]
        ar = data_info["ar"]
        pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)

        # embedding
        x = self.x_embedder(x)  # (B, N, D)
        x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
        x = x + pos_embed.to(x.device)
        if not self.no_temporal_pos_emb:
            x = rearrange(x, "b t s d -> b s t d")
            x = x + self.pos_embed_temporal
            x = rearrange(x, "b s t d -> b (t s) d")
        else:
            x = rearrange(x, "b t s d -> b (t s) d")

        t = self.t_embedder(timestep, dtype=x.dtype)  # (N, D)
        B = x.shape[0]
        csize = self.csize_embedder(c_size, B)
        ar = self.ar_embedder(ar, B)
        t = t + torch.cat([csize, ar], dim=1)
        t0 = self.t_block(t)

        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # blocks
        for block in self.blocks:
            x = block(x, y, t0, y_lens)

        # final process
        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x


@MODELS.register_module("PixArt-XL/2")
def PixArt_XL_2(from_pretrained=None, **kwargs):
    model = PixArt(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model


@MODELS.register_module("PixArtMS-XL/2")
def PixArtMS_XL_2(from_pretrained=None, **kwargs):
    model = PixArtMS(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model
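For reference, a minimal shape-check sketch for the PixArt class above. It assumes utils_data/ is on PYTHONPATH, a CUDA device, and xformers installed (MultiHeadCrossAttention calls xformers.ops.memory_efficient_attention); the tiny config is illustrative and not the registered PixArt-XL/2 sizes:

import torch
from opensora.models.pixart.pixart import PixArt  # assumes utils_data/ on PYTHONPATH

model = PixArt(
    input_size=(1, 16, 16), in_channels=4, patch_size=(1, 2, 2),
    hidden_size=64, depth=2, num_heads=4,
    caption_channels=32, model_max_length=8,
).cuda().eval()

x = torch.randn(2, 4, 1, 16, 16, device="cuda")  # (N, C, T, H, W) latents
t = torch.randint(0, 1000, (2,), device="cuda")  # diffusion timesteps
y = torch.randn(2, 1, 8, 32, device="cuda")      # (N, 1, model_max_length, caption_channels)
with torch.no_grad():
    out = model(x, t, y)
print(out.shape)  # (2, 8, 1, 16, 16): pred_sigma=True doubles the output channels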
utils_data/opensora/models/stdit/__init__.py (new file, mode 100644)
from .stdit import STDiT
#from .stdit_ftf import STDiT_FTF
from .stdit_qknorm_rope import STDiT_QKNorm_RoPE
#from .stdit_denseadd import STDiT_DenseAdd
#from .stdit_densecat import STDiT_DenseCat
#from .stdit_it2v import STDiT_IT2V
#from .stdit_densewmean import STDiT_DenseWmean
#from .stdit_longshort import STDiT_LS
#from .stdit_ttca import STDiT_TTCA
#from .stdit_densecat_norm import STDiT_DenseCatNorm
#from .ustdit import USTDiT
#from .stdit_densesecat import STDiT_DenseSECat
#from .stdit_densecat_mmdit import STDiT_DenseMM
from .stdit_mmdit import STDiT_MMDiT
from .stdit_freq import STDiT_freq
#from .stdit_densecat_multipos import STDiT_DenseCatMpos
#from .stdit_mmdit_nocross import STDiT_MMDiT_Nocross
#from .stdit_mmdit_tempsplit import STDiT_MMDiT_TempSplit
#from .stdit_mmdit_qk import STDiT_MMDiTQK
#from .stdit_mmdit_small import STDiT_MMDiT_Small
utils_data/opensora/models/stdit/__pycache__/__init__.cpython-39.pyc (new file, mode 100644; binary file added)
utils_data/opensora/models/stdit/__pycache__/stdit.cpython-39.pyc (new file, mode 100644; binary file added)
utils_data/opensora/models/stdit/__pycache__/stdit_freq.cpython-39.pyc (new file, mode 100644; binary file added)
utils_data/opensora/models/stdit/__pycache__/stdit_mmdit.cpython-39.pyc (new file, mode 100644; binary file added)
utils_data/opensora/models/stdit/__pycache__/stdit_qknorm_rope.cpython-39.pyc (new file, mode 100644; binary file added)
utils_data/opensora/models/stdit/stdit.py (new file, mode 100644)
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp

from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.models.layers.blocks import (
    Attention,
    CaptionEmbedder,
    MultiHeadCrossAttention,
    PatchEmbed3D,
    SeqParallelAttention,
    SeqParallelMultiHeadCrossAttention,
    T2IFinalLayer,
    TimestepEmbedder,
    approx_gelu,
    get_1d_sincos_pos_embed,
    get_2d_sincos_pos_embed,
    get_layernorm,
    t2i_modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint

# import ipdb


class STDiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        d_s=None,
        d_t=None,
        mlp_ratio=4.0,
        drop_path=0.0,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.enable_flashattn = enable_flashattn
        self._enable_sequence_parallelism = enable_sequence_parallelism

        if enable_sequence_parallelism:
            self.attn_cls = SeqParallelAttention
            self.mha_cls = SeqParallelMultiHeadCrossAttention
        else:
            # here
            self.attn_cls = Attention
            self.mha_cls = MultiHeadCrossAttention

        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=enable_flashattn,
        )
        self.cross_attn = self.mha_cls(hidden_size, num_heads)
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)

        # temporal attention
        self.d_s = d_s
        self.d_t = d_t

        if self._enable_sequence_parallelism:
            sp_size = dist.get_world_size(get_sequence_parallel_group())
            # make sure d_t is divisible by sp_size
            assert d_t % sp_size == 0
            self.d_t = d_t // sp_size

        self.attn_temp = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=self.enable_flashattn,
        )

    def forward(self, x, y, t, mask=None, tpe=None):
        B, N, C = x.shape
        # ipdb.set_trace()
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)

        # spatial branch
        x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
        x_s = self.attn(x_s)
        x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
        x = x + self.drop_path(gate_msa * x_s)

        # temporal branch
        x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
        if tpe is not None:
            x_t = x_t + tpe
        x_t = self.attn_temp(x_t)
        x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s)
        x = x + self.drop_path(gate_msa * x_t)

        # cross attn
        x = x + self.cross_attn(x, y, mask)

        # mlp
        x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

        return x


@MODELS.register_module()
class STDiT(nn.Module):
    def __init__(
        self,
        input_size=(1, 32, 32),
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path=0.0,
        no_temporal_pos_emb=False,
        caption_channels=4096,
        model_max_length=120,
        dtype=torch.float32,
        space_scale=1.0,
        time_scale=1.0,
        freeze=None,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
    ):
        super().__init__()
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.input_size = input_size
        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
        self.num_patches = num_patches
        self.num_temporal = input_size[0] // patch_size[0]
        self.num_spatial = num_patches // self.num_temporal
        self.num_heads = num_heads
        self.dtype = dtype
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.enable_flashattn = enable_flashattn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.space_scale = space_scale
        self.time_scale = time_scale

        self.register_buffer("pos_embed", self.get_spatial_pos_embed())
self
.
register_buffer
(
"pos_embed_temporal"
,
self
.
get_temporal_pos_embed
())
self
.
x_embedder
=
PatchEmbed3D
(
patch_size
,
in_channels
,
hidden_size
)
self
.
t_embedder
=
TimestepEmbedder
(
hidden_size
)
self
.
t_block
=
nn
.
Sequential
(
nn
.
SiLU
(),
nn
.
Linear
(
hidden_size
,
6
*
hidden_size
,
bias
=
True
))
self
.
y_embedder
=
CaptionEmbedder
(
in_channels
=
caption_channels
,
hidden_size
=
hidden_size
,
uncond_prob
=
class_dropout_prob
,
act_layer
=
approx_gelu
,
token_num
=
model_max_length
,
)
drop_path
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
drop_path
,
depth
)]
self
.
blocks
=
nn
.
ModuleList
(
[
STDiTBlock
(
self
.
hidden_size
,
self
.
num_heads
,
mlp_ratio
=
self
.
mlp_ratio
,
drop_path
=
drop_path
[
i
],
enable_flashattn
=
self
.
enable_flashattn
,
enable_layernorm_kernel
=
self
.
enable_layernorm_kernel
,
enable_sequence_parallelism
=
enable_sequence_parallelism
,
d_t
=
self
.
num_temporal
,
d_s
=
self
.
num_spatial
,
)
for
i
in
range
(
self
.
depth
)
]
)
self
.
final_layer
=
T2IFinalLayer
(
hidden_size
,
np
.
prod
(
self
.
patch_size
),
self
.
out_channels
)
# init model
self
.
initialize_weights
()
self
.
initialize_temporal
()
if
freeze
is
not
None
:
assert
freeze
in
[
"not_temporal"
,
"text"
]
if
freeze
==
"not_temporal"
:
self
.
freeze_not_temporal
()
elif
freeze
==
"text"
:
self
.
freeze_text
()
# sequence parallel related configs
self
.
enable_sequence_parallelism
=
enable_sequence_parallelism
if
enable_sequence_parallelism
:
self
.
sp_rank
=
dist
.
get_rank
(
get_sequence_parallel_group
())
else
:
self
.
sp_rank
=
None
def
forward
(
self
,
x
,
timestep
,
y
,
mask
=
None
):
"""
Forward pass of STDiT.
Args:
x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
timestep (torch.Tensor): diffusion time steps; of shape [B]
y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
Returns:
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
"""
x
=
x
.
to
(
self
.
dtype
)
timestep
=
timestep
.
to
(
self
.
dtype
)
y
=
y
.
to
(
self
.
dtype
)
# embedding
x
=
self
.
x_embedder
(
x
)
# [B, N, C]
x
=
rearrange
(
x
,
"B (T S) C -> B T S C"
,
T
=
self
.
num_temporal
,
S
=
self
.
num_spatial
)
x
=
x
+
self
.
pos_embed
x
=
rearrange
(
x
,
"B T S C -> B (T S) C"
)
# shard over the sequence dim if sp is enabled
if
self
.
enable_sequence_parallelism
:
x
=
split_forward_gather_backward
(
x
,
get_sequence_parallel_group
(),
dim
=
1
,
grad_scale
=
"down"
)
t
=
self
.
t_embedder
(
timestep
,
dtype
=
x
.
dtype
)
# [B, C]
t0
=
self
.
t_block
(
t
)
# [B, C]
y
=
self
.
y_embedder
(
y
,
self
.
training
)
# [B, 1, N_token, C]
if
mask
is
not
None
:
if
mask
.
shape
[
0
]
!=
y
.
shape
[
0
]:
mask
=
mask
.
repeat
(
y
.
shape
[
0
]
//
mask
.
shape
[
0
],
1
)
mask
=
mask
.
squeeze
(
1
).
squeeze
(
1
)
y
=
y
.
squeeze
(
1
).
masked_select
(
mask
.
unsqueeze
(
-
1
)
!=
0
).
view
(
1
,
-
1
,
x
.
shape
[
-
1
])
y_lens
=
mask
.
sum
(
dim
=
1
).
tolist
()
else
:
y_lens
=
[
y
.
shape
[
2
]]
*
y
.
shape
[
0
]
y
=
y
.
squeeze
(
1
).
view
(
1
,
-
1
,
x
.
shape
[
-
1
])
# blocks
for
i
,
block
in
enumerate
(
self
.
blocks
):
if
i
==
0
:
if
self
.
enable_sequence_parallelism
:
tpe
=
torch
.
chunk
(
self
.
pos_embed_temporal
,
dist
.
get_world_size
(
get_sequence_parallel_group
()),
dim
=
1
)[
self
.
sp_rank
].
contiguous
()
else
:
tpe
=
self
.
pos_embed_temporal
else
:
tpe
=
None
x
=
auto_grad_checkpoint
(
block
,
x
,
y
,
t0
,
y_lens
,
tpe
)
if
self
.
enable_sequence_parallelism
:
x
=
gather_forward_split_backward
(
x
,
get_sequence_parallel_group
(),
dim
=
1
,
grad_scale
=
"up"
)
# x.shape: [B, N, C]
# final process
x
=
self
.
final_layer
(
x
,
t
)
# [B, N, C=T_p * H_p * W_p * C_out]
x
=
self
.
unpatchify
(
x
)
# [B, C_out, T, H, W]
# cast to float32 for better accuracy
x
=
x
.
to
(
torch
.
float32
)
return
x
def
unpatchify
(
self
,
x
):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t
,
N_h
,
N_w
=
[
self
.
input_size
[
i
]
//
self
.
patch_size
[
i
]
for
i
in
range
(
3
)]
T_p
,
H_p
,
W_p
=
self
.
patch_size
x
=
rearrange
(
x
,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)"
,
N_t
=
N_t
,
N_h
=
N_h
,
N_w
=
N_w
,
T_p
=
T_p
,
H_p
=
H_p
,
W_p
=
W_p
,
C_out
=
self
.
out_channels
,
)
return
x
def
unpatchify_old
(
self
,
x
):
c
=
self
.
out_channels
t
,
h
,
w
=
[
self
.
input_size
[
i
]
//
self
.
patch_size
[
i
]
for
i
in
range
(
3
)]
pt
,
ph
,
pw
=
self
.
patch_size
x
=
x
.
reshape
(
shape
=
(
x
.
shape
[
0
],
t
,
h
,
w
,
pt
,
ph
,
pw
,
c
))
x
=
rearrange
(
x
,
"n t h w r p q c -> n c t r h p w q"
)
imgs
=
x
.
reshape
(
shape
=
(
x
.
shape
[
0
],
c
,
t
*
pt
,
h
*
ph
,
w
*
pw
))
return
imgs
def
get_spatial_pos_embed
(
self
,
grid_size
=
None
):
if
grid_size
is
None
:
grid_size
=
self
.
input_size
[
1
:]
pos_embed
=
get_2d_sincos_pos_embed
(
self
.
hidden_size
,
(
grid_size
[
0
]
//
self
.
patch_size
[
1
],
grid_size
[
1
]
//
self
.
patch_size
[
2
]),
scale
=
self
.
space_scale
,
)
pos_embed
=
torch
.
from_numpy
(
pos_embed
).
float
().
unsqueeze
(
0
).
requires_grad_
(
False
)
return
pos_embed
def
get_temporal_pos_embed
(
self
):
pos_embed
=
get_1d_sincos_pos_embed
(
self
.
hidden_size
,
self
.
input_size
[
0
]
//
self
.
patch_size
[
0
],
scale
=
self
.
time_scale
,
)
pos_embed
=
torch
.
from_numpy
(
pos_embed
).
float
().
unsqueeze
(
0
).
requires_grad_
(
False
)
return
pos_embed
def
freeze_not_temporal
(
self
):
for
n
,
p
in
self
.
named_parameters
():
if
"attn_temp"
not
in
n
:
p
.
requires_grad
=
False
def
freeze_text
(
self
):
for
n
,
p
in
self
.
named_parameters
():
if
"cross_attn"
in
n
:
p
.
requires_grad
=
False
def
initialize_temporal
(
self
):
for
block
in
self
.
blocks
:
nn
.
init
.
constant_
(
block
.
attn_temp
.
proj
.
weight
,
0
)
nn
.
init
.
constant_
(
block
.
attn_temp
.
proj
.
bias
,
0
)
def
initialize_weights
(
self
):
# Initialize transformer layers:
def
_basic_init
(
module
):
if
isinstance
(
module
,
nn
.
Linear
):
torch
.
nn
.
init
.
xavier_uniform_
(
module
.
weight
)
if
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
0
)
self
.
apply
(
_basic_init
)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w
=
self
.
x_embedder
.
proj
.
weight
.
data
nn
.
init
.
xavier_uniform_
(
w
.
view
([
w
.
shape
[
0
],
-
1
]))
# Initialize timestep embedding MLP:
nn
.
init
.
normal_
(
self
.
t_embedder
.
mlp
[
0
].
weight
,
std
=
0.02
)
nn
.
init
.
normal_
(
self
.
t_embedder
.
mlp
[
2
].
weight
,
std
=
0.02
)
nn
.
init
.
normal_
(
self
.
t_block
[
1
].
weight
,
std
=
0.02
)
# Initialize caption embedding MLP:
nn
.
init
.
normal_
(
self
.
y_embedder
.
y_proj
.
fc1
.
weight
,
std
=
0.02
)
nn
.
init
.
normal_
(
self
.
y_embedder
.
y_proj
.
fc2
.
weight
,
std
=
0.02
)
# Zero-out adaLN modulation layers in PixArt blocks:
for
block
in
self
.
blocks
:
nn
.
init
.
constant_
(
block
.
cross_attn
.
proj
.
weight
,
0
)
nn
.
init
.
constant_
(
block
.
cross_attn
.
proj
.
bias
,
0
)
# Zero-out output layers:
nn
.
init
.
constant_
(
self
.
final_layer
.
linear
.
weight
,
0
)
nn
.
init
.
constant_
(
self
.
final_layer
.
linear
.
bias
,
0
)
@
MODELS
.
register_module
(
"STDiT-XL/2"
)
def
STDiT_XL_2
(
from_pretrained
=
None
,
**
kwargs
):
model
=
STDiT
(
depth
=
28
,
hidden_size
=
1152
,
patch_size
=
(
1
,
2
,
2
),
num_heads
=
16
,
**
kwargs
)
if
from_pretrained
is
not
None
:
load_checkpoint
(
model
,
from_pretrained
)
return
model
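For orientation, and not part of the commit itself: a minimal smoke-test sketch of the STDiT model defined above, assuming the opensora package from this commit is importable and the default (non-flash) attention path is used. The tensor layouts follow the docstring of STDiT.forward; the concrete sizes are illustrative assumptions.

import torch
from opensora.models.stdit.stdit import STDiT

# 16 latent frames on a 32x32 latent grid; with patch_size=(1, 2, 2) this yields
# num_temporal=16 and num_spatial=256 tokens.
model = STDiT(input_size=(16, 32, 32), in_channels=4)
x = torch.randn(1, 4, 16, 32, 32)            # [B, C, T, H, W] latent video
timestep = torch.randint(0, 1000, (1,))      # [B] diffusion steps
y = torch.randn(1, 1, 120, 4096)             # [B, 1, model_max_length, caption_channels]
mask = torch.ones(1, 120, dtype=torch.long)  # keep all 120 prompt tokens
out = model(x, timestep, y, mask=mask)       # [1, 8, 16, 32, 32]: pred_sigma=True doubles the channels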
utils_data/opensora/models/stdit/stdit_controlnet.py (new file, mode 100644)

import re
import torch
import torch.nn as nn
from copy import deepcopy
from torch import Tensor
from torch.nn import Module, Linear, init
from typing import Any, Mapping

import sys

sys.path.append("/home/test/Workspace/ruixie/Open-Sora")
from einops import rearrange

# from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt
from opensora.models.stdit.stdit import STDiTBlock, STDiT
from opensora.models.layers.blocks import (
    Attention,
    CaptionEmbedder,
    MultiHeadCrossAttention,
    PatchEmbed3D,
    SeqParallelAttention,
    SeqParallelMultiHeadCrossAttention,
    T2IFinalLayer,
    TimestepEmbedder,
    approx_gelu,
    get_1d_sincos_pos_embed,
    get_2d_sincos_pos_embed,
    get_layernorm,
    t2i_modulate,
)
from opensora.acceleration.checkpoint import auto_grad_checkpoint


# The implementation of the ControlNet-Half architecture
# https://github.com/lllyasviel/ControlNet/discussions/188
class ControlT2IDitBlockHalf(Module):
    def __init__(self, base_block: STDiTBlock, block_index: 0) -> None:
        super().__init__()
        self.copied_block = deepcopy(base_block)
        self.block_index = block_index

        for p in self.copied_block.parameters():
            p.requires_grad_(True)
        self.copied_block.load_state_dict(base_block.state_dict())
        self.copied_block.train()

        self.hidden_size = hidden_size = base_block.hidden_size
        if self.block_index == 0:
            self.before_proj = Linear(hidden_size, hidden_size)
            init.zeros_(self.before_proj.weight)
            init.zeros_(self.before_proj.bias)
        self.after_proj = Linear(hidden_size, hidden_size)
        init.zeros_(self.after_proj.weight)
        init.zeros_(self.after_proj.bias)

    def forward(self, x, y, t0, y_lens, c, tpe):
        if self.block_index == 0:
            # the first block
            c = self.before_proj(c)
            c = self.copied_block(x + c, y, t0, y_lens, tpe)
            c_skip = self.after_proj(c)
        else:
            # load from previous c and produce the c for skip connection
            c = self.copied_block(c, y, t0, y_lens, tpe)
            c_skip = self.after_proj(c)
        return c, c_skip


# The implementation of the ControlPixArtHalf net
class ControlPixArtHalf(Module):
    # only supports the single-resolution model
    def __init__(self, base_model: STDiT, copy_blocks_num: int = 13) -> None:
        super().__init__()
        self.base_model = base_model.eval()
        self.controlnet = []
        self.copy_blocks_num = copy_blocks_num
        self.total_blocks_num = len(base_model.blocks)
        for p in self.base_model.parameters():
            p.requires_grad_(False)

        # Copy the first copy_blocks_num blocks
        for i in range(copy_blocks_num):
            self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
        self.controlnet = nn.ModuleList(self.controlnet)

    def __getattr__(self, name: str) -> Tensor or Module:
        if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
            return self.__dict__[name]
        elif name in ['base_model', 'controlnet']:
            return super().__getattr__(name)
        else:
            return getattr(self.base_model, name)

    def forward_c(self, c):
        # self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size
        # pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype)
        x = self.x_embedder(c)  # [B, N, C]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")
        return x if c is not None else c

    # def forward(self, x, t, c, **kwargs):
    #     return self.base_model(x, t, c=self.forward_c(c), **kwargs)
    def forward(self, x, timestep, y, mask=None, x_mask=None, c=None):
        # modify the original PixArtMS forward function
        if c is not None:
            c = c.to(self.dtype)
            c = self.forward_c(c)
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # [B, N, C]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")

        # shard over the sequence dim if sp is enabled
        if self.enable_sequence_parallelism:
            x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")

        t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        t0 = self.t_block(t)  # [B, C]
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # define the first layer
        # y_ori = y
        tpe = self.pos_embed_temporal
        x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, tpe)

        # define the rest layers
        # update c
        for index in range(1, self.copy_blocks_num + 1):
            if index == 1:
                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, tpe)
            else:
                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, None)
            x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, None)

        # update x
        for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
            x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, None)

        # final process
        x = self.final_layer(x, t)  # [B, N, C=T_p * H_p * W_p * C_out]
        x = self.unpatchify(x)  # [B, C_out, T, H, W]

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
            return super().load_state_dict(state_dict, strict)
        else:
            new_key = {}
            for k in state_dict.keys():
                new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
            for k, v in new_key.items():
                if k != v:
                    print(f"replace {k} to {v}")
                    state_dict[v] = state_dict.pop(k)
            return self.base_model.load_state_dict(state_dict, strict)

    def unpatchify(self, x):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x,
            "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t,
            N_h=N_h,
            N_w=N_w,
            T_p=T_p,
            H_p=H_p,
            W_p=W_p,
            C_out=self.out_channels,
        )
        return x

    def unpatchify_old(self, x):
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size

        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def get_spatial_pos_embed(self, grid_size=None):
        if grid_size is None:
            grid_size = self.input_size[1:]
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
            scale=self.space_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[0] // self.patch_size[0],
            scale=self.time_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def freeze_not_temporal(self):
        for n, p in self.named_parameters():
            if "attn_temp" not in n:
                p.requires_grad = False

    def freeze_text(self):
        for n, p in self.named_parameters():
            if "cross_attn" in n:
                p.requires_grad = False

    def initialize_temporal(self):
        for block in self.blocks:
            nn.init.constant_(block.attn_temp.proj.weight, 0)
            nn.init.constant_(block.attn_temp.proj.bias, 0)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
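A hedged usage sketch (again not part of the commit) of how ControlPixArtHalf wraps a frozen STDiT: the base model's parameters are set to requires_grad=False, the first copy_blocks_num blocks are duplicated as trainable control blocks, and the conditioning latent c enters through the zero-initialized before_proj/after_proj layers, so the wrapper initially reproduces the base model's output. The tensor sizes below are assumptions, and the conditioning latent is assumed to share the layout of x.

import torch
from opensora.models.stdit.stdit import STDiT
from opensora.models.stdit.stdit_controlnet import ControlPixArtHalf

base = STDiT(input_size=(16, 32, 32), in_channels=4)
control = ControlPixArtHalf(base, copy_blocks_num=13)

x = torch.randn(1, 4, 16, 32, 32)        # noisy latent video
c = torch.randn(1, 4, 16, 32, 32)        # conditioning latent, embedded by forward_c
timestep = torch.randint(0, 1000, (1,))
y = torch.randn(1, 1, 120, 4096)
out = control(x, timestep, y, mask=None, c=c)  # same shape as the base model's output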
utils_data/opensora/models/stdit/stdit_controlnet_freq.py (new file, mode 100644)

import re
import torch
import torch.nn as nn
from copy import deepcopy
from torch import Tensor
from torch.nn import Module, Linear, init
from typing import Any, Mapping

# import pywt
import torch.fft
import numpy as np
import sys

sys.path.append("/mnt/bn/videodataset/VSR/VSR")
from einops import rearrange

# from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt
from opensora.models.stdit.stdit import STDiTBlock, STDiT
from opensora.models.layers.blocks import (
    Attention,
    CaptionEmbedder,
    MultiHeadCrossAttention,
    PatchEmbed3D,
    SeqParallelAttention,
    SeqParallelMultiHeadCrossAttention,
    T2IFinalLayer,
    TimestepEmbedder,
    approx_gelu,
    get_1d_sincos_pos_embed,
    get_2d_sincos_pos_embed,
    get_layernorm,
    t2i_modulate,
)
from opensora.acceleration.checkpoint import auto_grad_checkpoint
import torch.nn.functional as F
from einops import rearrange


# The implementation of the ControlNet-Half architecture
# https://github.com/lllyasviel/ControlNet/discussions/188
class ControlT2IDitBlockHalf(Module):
    def __init__(self, base_block: STDiTBlock, block_index: 0) -> None:
        super().__init__()
        self.copied_block = deepcopy(base_block)
        self.block_index = block_index

        for p in self.copied_block.parameters():
            p.requires_grad_(True)
        self.copied_block.load_state_dict(base_block.state_dict())
        self.copied_block.train()

        self.hidden_size = hidden_size = base_block.hidden_size
        if self.block_index == 0:
            self.before_proj = Linear(hidden_size, hidden_size)
            init.zeros_(self.before_proj.weight)
            init.zeros_(self.before_proj.bias)
        self.after_proj = Linear(hidden_size, hidden_size)
        init.zeros_(self.after_proj.weight)
        init.zeros_(self.after_proj.bias)

    def forward(self, x, y, t0, y_lens, c, tpe, hf_fea, lf_fea, temp_fea):
        if self.block_index == 0:
            # the first block
            c = self.before_proj(c)
            c = self.copied_block(x + c, y, t0, y_lens, tpe, hf_fea, lf_fea, temp_fea)
            c_skip = self.after_proj(c)
        else:
            # load from previous c and produce the c for skip connection
            c = self.copied_block(c, y, t0, y_lens, tpe, hf_fea, lf_fea, temp_fea)
            c_skip = self.after_proj(c)
        return c, c_skip


# The implementation of the ControlPixArtHalf net
class ControlPixArtHalf(Module):
    # only supports the single-resolution model
    def __init__(self, base_model: STDiT, copy_blocks_num: int = 13) -> None:
        super().__init__()
        self.base_model = base_model
        self.controlnet = []
        self.copy_blocks_num = copy_blocks_num
        self.total_blocks_num = len(base_model.blocks)
        for p in self.base_model.parameters():
            p.requires_grad_(False)

        # Copy the first copy_blocks_num blocks
        for i in range(copy_blocks_num):
            self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
        self.controlnet = nn.ModuleList(self.controlnet)

    def __getattr__(self, name: str) -> Tensor or Module:
        if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
            return self.__dict__[name]
        elif name in ['base_model', 'controlnet']:
            return super().__getattr__(name)
        else:
            return getattr(self.base_model, name)

    def forward_c(self, c):
        ### Controlnet Input ###
        x = self.x_embedder(c)  # [B, N, 1152]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")
        return x

    def forward_hf(self, hf_part):
        ### Controlnet Input ###
        x = self.hf_embedder(hf_part)  # [B, N, 1152]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")
        return x

    def forward_lf(self, lf_part):
        ### Controlnet Input ###
        x = self.lf_embedder(lf_part)  # [B, N, 1152]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")
        return x

    def forward(self, x, timestep, y, mask=None, c=None, lr=None):
        # modify the original PixArtMS forward function
        if c is not None:
            c = c.to(self.dtype)
            c = auto_grad_checkpoint(self.forward_c, c)

        # generate spatial & temporal information
        _, hf_part, lf_part = self.fdie.spatial_forward(lr)
        hf_fea = auto_grad_checkpoint(self.forward_hf, hf_part)
        lf_fea = auto_grad_checkpoint(self.forward_lf, lf_part)
        temp_fea = self.fdie.temporal_forward(lf_fea)
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # [B, N, C]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")

        # shard over the sequence dim if sp is enabled
        if self.enable_sequence_parallelism:
            x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")

        t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        t0 = self.t_block(t)  # [B, C]
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # define the first layer
        # y_ori = y
        tpe = self.pos_embed_temporal
        x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, tpe, hf_fea, lf_fea, temp_fea)

        # define the rest layers
        # update c
        for index in range(1, self.copy_blocks_num + 1):
            if index == 1:
                c, c_skip = auto_grad_checkpoint(
                    self.controlnet[index - 1], x, y, t0, y_lens, c, tpe, hf_fea, lf_fea, temp_fea
                )
            else:
                c, c_skip = auto_grad_checkpoint(
                    self.controlnet[index - 1], x, y, t0, y_lens, c, None, hf_fea, lf_fea, temp_fea
                )
            x = auto_grad_checkpoint(
                self.base_model.blocks[index], x + c_skip, y, t0, y_lens, None, hf_fea, lf_fea, temp_fea
            )

        # update x
        for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
            x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, None, hf_fea, lf_fea, temp_fea)

        # final process
        x = self.final_layer(x, t)  # [B, N, C=T_p * H_p * W_p * C_out]
        x = self.unpatchify(x)  # [B, C_out, T, H, W]

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
            return super().load_state_dict(state_dict, strict)
        else:
            new_key = {}
            for k in state_dict.keys():
                new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
            for k, v in new_key.items():
                if k != v:
                    print(f"replace {k} to {v}")
                    state_dict[v] = state_dict.pop(k)
            return self.base_model.load_state_dict(state_dict, strict)

    def unpatchify(self, x):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x,
            "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t,
            N_h=N_h,
            N_w=N_w,
            T_p=T_p,
            H_p=H_p,
            W_p=W_p,
            C_out=self.out_channels,
        )
        return x

    def unpatchify_old(self, x):
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size

        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def get_spatial_pos_embed(self, grid_size=None):
        if grid_size is None:
            grid_size = self.input_size[1:]
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
            scale=self.space_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[0] // self.patch_size[0],
            scale=self.time_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def freeze_not_temporal(self):
        for n, p in self.named_parameters():
            if "attn_temp" not in n:
                p.requires_grad = False

    def freeze_text(self):
        for n, p in self.named_parameters():
            if "cross_attn" in n:
                p.requires_grad = False

    def initialize_temporal(self):
        for block in self.blocks:
            nn.init.constant_(block.attn_temp.proj.weight, 0)
            nn.init.constant_(block.attn_temp.proj.bias, 0)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
utils_data/opensora/models/stdit/stdit_controlnet_mvdit.py (new file, mode 100644)

import re
import torch
import torch.nn as nn
from copy import deepcopy
from torch import Tensor
from torch.nn import Module, Linear, init
from typing import Any, Mapping

import sys

sys.path.append("/home/test/Workspace/ruixie/Open-Sora")
from einops import rearrange

# from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt
from opensora.models.stdit.stdit import STDiTBlock, STDiT
from opensora.models.layers.blocks import (
    Attention,
    CaptionEmbedder,
    MultiHeadCrossAttention,
    PatchEmbed3D,
    SeqParallelAttention,
    SeqParallelMultiHeadCrossAttention,
    T2IFinalLayer,
    TimestepEmbedder,
    approx_gelu,
    get_1d_sincos_pos_embed,
    get_2d_sincos_pos_embed,
    get_layernorm,
    t2i_modulate,
)
from opensora.acceleration.checkpoint import auto_grad_checkpoint


# The implementation of the ControlNet-Half architecture
# https://github.com/lllyasviel/ControlNet/discussions/188
class ControlT2IDitBlockHalf(Module):
    def __init__(self, base_block: STDiTBlock, block_index: 0) -> None:
        super().__init__()
        self.copied_block = deepcopy(base_block)
        self.block_index = block_index

        for p in self.copied_block.parameters():
            p.requires_grad_(True)
        self.copied_block.load_state_dict(base_block.state_dict())
        self.copied_block.train()

        self.hidden_size = hidden_size = base_block.hidden_size
        if self.block_index == 0:
            self.before_proj = Linear(hidden_size, hidden_size)
            init.zeros_(self.before_proj.weight)
            init.zeros_(self.before_proj.bias)
        self.after_proj = Linear(hidden_size, hidden_size)
        init.zeros_(self.after_proj.weight)
        init.zeros_(self.after_proj.bias)

    def forward(self, x, y, t0, t_y, t0_tmep, t_y_tmep, mask, c, tpe):
        if self.block_index == 0:
            # the first block
            c = self.before_proj(c)
            c, y = self.copied_block(x + c, y, t0, t_y, t0_tmep, t_y_tmep, mask, tpe)
            c_skip = self.after_proj(c)
        else:
            # load from previous c and produce the c for skip connection
            c, y = self.copied_block(c, y, t0, t_y, t0_tmep, t_y_tmep, mask, tpe)
            c_skip = self.after_proj(c)
        return c, c_skip, y


# The implementation of the ControlPixArtHalf net
class ControlPixArtHalf(Module):
    # only supports the single-resolution model
    def __init__(self, base_model: STDiT, copy_blocks_num: int = 13) -> None:
        super().__init__()
        self.base_model = base_model.eval()
        self.controlnet = []
        self.copy_blocks_num = copy_blocks_num
        self.total_blocks_num = len(base_model.blocks)
        for p in self.base_model.parameters():
            p.requires_grad_(False)

        # Copy the first copy_blocks_num blocks
        for i in range(copy_blocks_num):
            self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
        self.controlnet = nn.ModuleList(self.controlnet)

    def __getattr__(self, name: str) -> Tensor or Module:
        if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
            return self.__dict__[name]
        elif name in ['base_model', 'controlnet']:
            return super().__getattr__(name)
        else:
            return getattr(self.base_model, name)

    def forward_c(self, c):
        # self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size
        # pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype)
        x = self.x_embedder(c)  # [B, N, C]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")
        return x if c is not None else c

    # def forward(self, x, t, c, **kwargs):
    #     return self.base_model(x, t, c=self.forward_c(c), **kwargs)
    def forward(self, x, timestep, y, mask=None, x_mask=None, c=None):
        # modify the original PixArtMS forward function
        if c is not None:
            # print("Process condition")
            c = c.to(self.dtype)
            c = self.forward_c(c)
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # [B, N, C]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")

        # shard over the sequence dim if sp is enabled
        if self.enable_sequence_parallelism:
            x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")

        # t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        # t_spc_mlp = self.t_block(t)  # [B, 6*C]
        # t_tmp_mlp = self.t_block_temp(t)  # [B, 3*C]
        t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        t0 = self.t_block(t)  # [B, C]
        t_y = self.t_block_y(t)
        t0_tmep = self.t_block_temp(t)  # [B, C]
        t_y_tmep = self.t_block_y_temp(t)
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # define the first layer
        y_ori = y
        tpe = self.pos_embed_temporal
        x, y_x = auto_grad_checkpoint(
            self.base_model.blocks[0], x, y_ori, t0, t_y, t0_tmep, t_y_tmep, y_lens, tpe
        )

        # define the rest layers
        # update c
        for index in range(1, self.copy_blocks_num + 1):
            if index == 1:
                c, c_skip, y_c = auto_grad_checkpoint(
                    self.controlnet[index - 1], x, y_ori, t0, t_y, t0_tmep, t_y_tmep, y_lens, c, tpe
                )
            else:
                c, c_skip, y_c = auto_grad_checkpoint(
                    self.controlnet[index - 1], x, y_c, t0, t_y, t0_tmep, t_y_tmep, y_lens, c, None
                )
            x, y_x = auto_grad_checkpoint(
                self.base_model.blocks[index], x + c_skip, y_x, t0, t_y, t0_tmep, t_y_tmep, y_lens, None
            )

        # update x
        for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
            x, y_x = auto_grad_checkpoint(
                self.base_model.blocks[index], x, y_x, t0, t_y, t0_tmep, t_y_tmep, y_lens, None
            )

        # final process
        x = self.final_layer(x, t)  # [B, N, C=T_p * H_p * W_p * C_out]
        x = self.unpatchify(x)  # [B, C_out, T, H, W]

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
            return super().load_state_dict(state_dict, strict)
        else:
            new_key = {}
            for k in state_dict.keys():
                new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
            for k, v in new_key.items():
                if k != v:
                    print(f"replace {k} to {v}")
                    state_dict[v] = state_dict.pop(k)
            return self.base_model.load_state_dict(state_dict, strict)

    def unpatchify(self, x):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x,
            "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t,
            N_h=N_h,
            N_w=N_w,
            T_p=T_p,
            H_p=H_p,
            W_p=W_p,
            C_out=self.out_channels,
        )
        return x

    def unpatchify_old(self, x):
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size

        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def get_spatial_pos_embed(self, grid_size=None):
        if grid_size is None:
            grid_size = self.input_size[1:]
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
            scale=self.space_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[0] // self.patch_size[0],
            scale=self.time_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def freeze_not_temporal(self):
        for n, p in self.named_parameters():
            if "attn_temp" not in n:
                p.requires_grad = False

    def freeze_text(self):
        for n, p in self.named_parameters():
            if "cross_attn" in n:
                p.requires_grad = False

    def initialize_temporal(self):
        for block in self.blocks:
            nn.init.constant_(block.attn_temp.proj.weight, 0)
            nn.init.constant_(block.attn_temp.proj.bias, 0)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
utils_data/opensora/models/stdit/stdit_controlnet_qknorm.py (new file, mode 100644)

import re
import torch
import torch.nn as nn
from copy import deepcopy
from torch import Tensor
from torch.nn import Module, Linear, init
from typing import Any, Mapping

import sys

sys.path.append("/home/test/Workspace/ruixie/Open-Sora")
from einops import rearrange

# from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt
from opensora.models.stdit.stdit import STDiTBlock, STDiT
from opensora.models.layers.blocks import (
    Attention,
    CaptionEmbedder,
    MultiHeadCrossAttention,
    PatchEmbed3D,
    SeqParallelAttention,
    SeqParallelMultiHeadCrossAttention,
    T2IFinalLayer,
    TimestepEmbedder,
    approx_gelu,
    get_1d_sincos_pos_embed,
    get_2d_sincos_pos_embed,
    get_layernorm,
    t2i_modulate,
)
from opensora.acceleration.checkpoint import auto_grad_checkpoint


# The implementation of the ControlNet-Half architecture
# https://github.com/lllyasviel/ControlNet/discussions/188
class ControlT2IDitBlockHalf(Module):
    def __init__(self, base_block: STDiTBlock, block_index: 0) -> None:
        super().__init__()
        self.copied_block = deepcopy(base_block)
        self.block_index = block_index

        for p in self.copied_block.parameters():
            p.requires_grad_(True)
        self.copied_block.load_state_dict(base_block.state_dict())
        self.copied_block.train()

        self.hidden_size = hidden_size = base_block.hidden_size
        if self.block_index == 0:
            self.before_proj = Linear(hidden_size, hidden_size)
            init.zeros_(self.before_proj.weight)
            init.zeros_(self.before_proj.bias)
        self.after_proj = Linear(hidden_size, hidden_size)
        init.zeros_(self.after_proj.weight)
        init.zeros_(self.after_proj.bias)

    def forward(self, x, y, t0, t0_tmep, y_lens, c, tpe):
        if self.block_index == 0:
            # the first block
            c = self.before_proj(c)
            c = self.copied_block(x + c, y, t0, t0_tmep, y_lens, tpe)
            c_skip = self.after_proj(c)
        else:
            # load from previous c and produce the c for skip connection
            c = self.copied_block(c, y, t0, t0_tmep, y_lens, tpe)
            c_skip = self.after_proj(c)
        return c, c_skip


# The implementation of the ControlPixArtHalf net
class ControlPixArtHalf(Module):
    # only supports the single-resolution model
    def __init__(self, base_model: STDiT, copy_blocks_num: int = 13) -> None:
        super().__init__()
        self.base_model = base_model.eval()
        self.controlnet = []
        self.copy_blocks_num = copy_blocks_num
        self.total_blocks_num = len(base_model.blocks)
        for p in self.base_model.parameters():
            p.requires_grad_(False)

        # Copy the first copy_blocks_num blocks
        for i in range(copy_blocks_num):
            self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
        self.controlnet = nn.ModuleList(self.controlnet)

    def __getattr__(self, name: str) -> Tensor or Module:
        if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
            return self.__dict__[name]
        elif name in ['base_model', 'controlnet']:
            return super().__getattr__(name)
        else:
            return getattr(self.base_model, name)

    def forward_c(self, c):
        # self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size
        # pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype)
        x = self.x_embedder(c)  # [B, N, C]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")
        return x if c is not None else c

    # def forward(self, x, t, c, **kwargs):
    #     return self.base_model(x, t, c=self.forward_c(c), **kwargs)
    def forward(self, x, timestep, y, mask=None, x_mask=None, c=None):
        # modify the original PixArtMS forward function
        if c is not None:
            c = c.to(self.dtype)
            c = self.forward_c(c)
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # [B, N, C]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")

        # shard over the sequence dim if sp is enabled
        if self.enable_sequence_parallelism:
            x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")

        t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        t0 = self.t_block(t)  # [B, C]
        t0_temp = self.t_block_temp(t)  # [B, C]
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # define the first layer
        # y_ori = y
        tpe = self.pos_embed_temporal
        x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, t0_temp, y_lens, tpe)

        # define the rest layers
        # update c
        for index in range(1, self.copy_blocks_num + 1):
            if index == 1:
                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, t0_temp, y_lens, c, tpe)
            else:
                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, t0_temp, y_lens, c, None)
            x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, t0_temp, y_lens, None)

        # update x
        for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
            x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, t0_temp, y_lens, None)

        # final process
        x = self.final_layer(x, t)  # [B, N, C=T_p * H_p * W_p * C_out]
        x = self.unpatchify(x)  # [B, C_out, T, H, W]

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
            return super().load_state_dict(state_dict, strict)
        else:
            new_key = {}
            for k in state_dict.keys():
                new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
            for k, v in new_key.items():
                if k != v:
                    print(f"replace {k} to {v}")
                    state_dict[v] = state_dict.pop(k)
            return self.base_model.load_state_dict(state_dict, strict)

    def unpatchify(self, x):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x,
            "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t,
            N_h=N_h,
            N_w=N_w,
            T_p=T_p,
            H_p=H_p,
            W_p=W_p,
            C_out=self.out_channels,
        )
        return x

    def unpatchify_old(self, x):
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size

        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def get_spatial_pos_embed(self, grid_size=None):
        if grid_size is None:
            grid_size = self.input_size[1:]
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
            scale=self.space_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[0] // self.patch_size[0],
            scale=self.time_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def freeze_not_temporal(self):
        for n, p in self.named_parameters():
            if "attn_temp" not in n:
                p.requires_grad = False

    def freeze_text(self):
        for n, p in self.named_parameters():
            if "cross_attn" in n:
                p.requires_grad = False

    def initialize_temporal(self):
        for block in self.blocks:
            nn.init.constant_(block.attn_temp.proj.weight, 0)
            nn.init.constant_(block.attn_temp.proj.bias, 0)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
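One non-obvious detail shared by the load_state_dict overrides in the ControlNet variants above: when a checkpoint is not already namespaced under base_model/controlnet, every blocks.<i>.* key is rewritten to blocks.<i>.base_block.* before the dict is handed to the base model. A quick standard-library illustration of that regex, using a hypothetical checkpoint key name:

import re

key = "blocks.3.attn.qkv.weight"  # hypothetical checkpoint key
print(re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", key))
# prints: blocks.3.base_block.attn.qkv.weight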
utils_data/opensora/models/stdit/stdit_freq.py (new file, mode 100644)

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp

from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.models.layers.blocks import (
    Attention,
    CaptionEmbedder,
    MultiHeadCrossAttention,
    PatchEmbed3D,
    SeqParallelAttention,
    SeqParallelMultiHeadCrossAttention,
    T2IFinalLayer,
    TimestepEmbedder,
    approx_gelu,
    get_1d_sincos_pos_embed,
    get_2d_sincos_pos_embed,
    get_layernorm,
    t2i_modulate,
    SpatialFrequencyBlcok,
    TemporalFrequencyBlock,
    Encoder_3D,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
from opensora.models.vsr.sfr_lftg import SpatialFeatureRefiner, LFTemporalGuider
from opensora.models.vsr.fdie_arch import FrequencyDecoupledInfoExtractor


class STDiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        d_s=None,
        d_t=None,
        mlp_ratio=4.0,
        drop_path=0.0,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.enable_flashattn = enable_flashattn
        self._enable_sequence_parallelism = enable_sequence_parallelism

        if enable_sequence_parallelism:
            self.attn_cls = SeqParallelAttention
            self.mha_cls = SeqParallelMultiHeadCrossAttention
        else:
            # here
            self.attn_cls = Attention
            self.mha_cls = MultiHeadCrossAttention

        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=enable_flashattn,
        )
        self.cross_attn = self.mha_cls(hidden_size, num_heads)
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)

        # temporal attention
        self.d_s = d_s
        self.d_t = d_t

        if self._enable_sequence_parallelism:
            sp_size = dist.get_world_size(get_sequence_parallel_group())
            # make sure d_t is divisible by sp_size
            assert d_t % sp_size == 0
            self.d_t = d_t // sp_size

        self.attn_temp = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=self.enable_flashattn,
        )

        # frequency block for spatial & temporal (New)
        self.sfr = SpatialFeatureRefiner(hidden_channels=hidden_size)
        self.lftg = LFTemporalGuider(d_model=hidden_size, num_heads=num_heads)

    def forward(self, x, y, t, mask=None, tpe=None, hf_fea=None, lf_fea=None, temp_fea=None):
        B, _, _ = x.shape
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)

        # spatial feature refiner (New)
        x_s = self.sfr(hf_fea, lf_fea, x_m)

        # spatial branch
        x_s = rearrange(x_s, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
        x_s = self.attn(x_s)
        x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
        x = x + self.drop_path(gate_msa * x_s)

        # LF temporal guider (New)
        x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
        temp_fea = rearrange(temp_fea, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
        # added tpe in the fdie
        if tpe is not None:
            x_t = x_t + tpe
        x_t = self.lftg(x_t, temp_fea)

        # temporal branch
        x_t = self.attn_temp(x_t)
        x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s)
        x = x + self.drop_path(gate_msa * x_t)

        # cross attn
        x = x + self.cross_attn(x, y, mask)

        # mlp
        x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

        return x


@MODELS.register_module()
class STDiT_freq(nn.Module):
    def __init__(
        self,
        input_size=(1, 32, 32),
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path=0.0,
        no_temporal_pos_emb=False,
        caption_channels=4096,
        model_max_length=120,
        dtype=torch.float32,
        space_scale=1.0,
        time_scale=1.0,
        freeze=None,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
    ):
        super().__init__()
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.input_size = input_size
        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
        self.num_patches = num_patches
        self.num_temporal = input_size[0] // patch_size[0]
        self.num_spatial = num_patches // self.num_temporal
        self.num_heads = num_heads
        self.dtype = dtype
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.enable_flashattn = enable_flashattn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.space_scale = space_scale
        self.time_scale = time_scale

        self.register_buffer("pos_embed", self.get_spatial_pos_embed())
        self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())

        self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels,
            hidden_size=hidden_size,
            uncond_prob=class_dropout_prob,
            act_layer=approx_gelu,
            token_num=model_max_length,
        )

        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
        self.blocks = nn.ModuleList(
            [
                STDiTBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    drop_path=drop_path[i],
                    enable_flashattn=self.enable_flashattn,
                    enable_layernorm_kernel=self.enable_layernorm_kernel,
                    enable_sequence_parallelism=enable_sequence_parallelism,
                    d_t=self.num_temporal,
                    d_s=self.num_spatial,
                )
                for i in range(self.depth)
            ]
        )
        self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)

        # frequency-decoupled information extractor
        self.fdie = FrequencyDecoupledInfoExtractor(in_channels=3, hidden_channels=64)
        # high-frequency & low-frequency embedder
        self.hf_embedder = PatchEmbed3D(patch_size=(1, 16, 16), in_chans=3, embed_dim=hidden_size
)
self
.
lf_embedder
=
PatchEmbed3D
(
patch_size
=
(
1
,
16
,
16
),
in_chans
=
3
,
embed_dim
=
hidden_size
)
# init model
self
.
initialize_weights
()
self
.
initialize_temporal
()
# sequence parallel related configs
self
.
enable_sequence_parallelism
=
enable_sequence_parallelism
if
enable_sequence_parallelism
:
self
.
sp_rank
=
dist
.
get_rank
(
get_sequence_parallel_group
())
else
:
self
.
sp_rank
=
None
def
forward
(
self
,
x
,
timestep
,
y
,
mask
=
None
):
"""
Forward pass of STDiT.
Args:
x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
timestep (torch.Tensor): diffusion time steps; of shape [B]
y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
Returns:
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
"""
x
=
x
.
to
(
self
.
dtype
)
timestep
=
timestep
.
to
(
self
.
dtype
)
y
=
y
.
to
(
self
.
dtype
)
# embedding
x
=
self
.
x_embedder
(
x
)
# [B, N, C]
x
=
rearrange
(
x
,
"B (T S) C -> B T S C"
,
T
=
self
.
num_temporal
,
S
=
self
.
num_spatial
)
x
=
x
+
self
.
pos_embed
x
=
rearrange
(
x
,
"B T S C -> B (T S) C"
)
# shard over the sequence dim if sp is enabled
if
self
.
enable_sequence_parallelism
:
x
=
split_forward_gather_backward
(
x
,
get_sequence_parallel_group
(),
dim
=
1
,
grad_scale
=
"down"
)
t
=
self
.
t_embedder
(
timestep
,
dtype
=
x
.
dtype
)
# [B, C]
t0
=
self
.
t_block
(
t
)
# [B, C]
y
=
self
.
y_embedder
(
y
,
self
.
training
)
# [B, 1, N_token, C]
if
mask
is
not
None
:
if
mask
.
shape
[
0
]
!=
y
.
shape
[
0
]:
mask
=
mask
.
repeat
(
y
.
shape
[
0
]
//
mask
.
shape
[
0
],
1
)
mask
=
mask
.
squeeze
(
1
).
squeeze
(
1
)
y
=
y
.
squeeze
(
1
).
masked_select
(
mask
.
unsqueeze
(
-
1
)
!=
0
).
view
(
1
,
-
1
,
x
.
shape
[
-
1
])
y_lens
=
mask
.
sum
(
dim
=
1
).
tolist
()
else
:
y_lens
=
[
y
.
shape
[
2
]]
*
y
.
shape
[
0
]
y
=
y
.
squeeze
(
1
).
view
(
1
,
-
1
,
x
.
shape
[
-
1
])
# blocks
for
i
,
block
in
enumerate
(
self
.
blocks
):
if
i
==
0
:
if
self
.
enable_sequence_parallelism
:
tpe
=
torch
.
chunk
(
self
.
pos_embed_temporal
,
dist
.
get_world_size
(
get_sequence_parallel_group
()),
dim
=
1
)[
self
.
sp_rank
].
contiguous
()
else
:
tpe
=
self
.
pos_embed_temporal
else
:
tpe
=
None
x
=
auto_grad_checkpoint
(
block
,
x
,
y
,
t0
,
y_lens
,
tpe
)
if
self
.
enable_sequence_parallelism
:
x
=
gather_forward_split_backward
(
x
,
get_sequence_parallel_group
(),
dim
=
1
,
grad_scale
=
"up"
)
# x.shape: [B, N, C]
# final process
x
=
self
.
final_layer
(
x
,
t
)
# [B, N, C=T_p * H_p * W_p * C_out]
x
=
self
.
unpatchify
(
x
)
# [B, C_out, T, H, W]
# cast to float32 for better accuracy
x
=
x
.
to
(
torch
.
float32
)
return
x
def
unpatchify
(
self
,
x
):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t
,
N_h
,
N_w
=
[
self
.
input_size
[
i
]
//
self
.
patch_size
[
i
]
for
i
in
range
(
3
)]
T_p
,
H_p
,
W_p
=
self
.
patch_size
x
=
rearrange
(
x
,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)"
,
N_t
=
N_t
,
N_h
=
N_h
,
N_w
=
N_w
,
T_p
=
T_p
,
H_p
=
H_p
,
W_p
=
W_p
,
C_out
=
self
.
out_channels
,
)
return
x
def
unpatchify_old
(
self
,
x
):
c
=
self
.
out_channels
t
,
h
,
w
=
[
self
.
input_size
[
i
]
//
self
.
patch_size
[
i
]
for
i
in
range
(
3
)]
pt
,
ph
,
pw
=
self
.
patch_size
x
=
x
.
reshape
(
shape
=
(
x
.
shape
[
0
],
t
,
h
,
w
,
pt
,
ph
,
pw
,
c
))
x
=
rearrange
(
x
,
"n t h w r p q c -> n c t r h p w q"
)
imgs
=
x
.
reshape
(
shape
=
(
x
.
shape
[
0
],
c
,
t
*
pt
,
h
*
ph
,
w
*
pw
))
return
imgs
    def get_spatial_pos_embed(self, grid_size=None):
        if grid_size is None:
            grid_size = self.input_size[1:]
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size, (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
            scale=self.space_scale,
        )
        # make sure pos_embed is a NumPy ndarray before conversion
        pos_embed = np.array(pos_embed)
        # convert to a non-trainable float32 tensor with a leading batch dim
        pos_embed = torch.tensor(pos_embed, dtype=torch.float32).unsqueeze(0).requires_grad_(False)
        return pos_embed
    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size, self.input_size[0] // self.patch_size[0], scale=self.time_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def freeze_not_temporal(self):
        for n, p in self.named_parameters():
            if "attn_temp" not in n:
                p.requires_grad = False

    def freeze_text(self):
        for n, p in self.named_parameters():
            if "cross_attn" in n:
                p.requires_grad = False

    def initialize_temporal(self):
        for block in self.blocks:
            nn.init.constant_(block.attn_temp.proj.weight, 0)
            nn.init.constant_(block.attn_temp.proj.bias, 0)
            # nn.init.constant_(block.temporal_freq.proj.weight, 0)
            # nn.init.constant_(block.temporal_freq.proj.bias, 0)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        e = self.lf_embedder.proj.weight.data
        nn.init.xavier_uniform_(e.view([e.shape[0], -1]))
        r = self.hf_embedder.proj.weight.data
        nn.init.xavier_uniform_(r.view([r.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)


@MODELS.register_module("STDiT-freq-XL/2")
def STDiT_freq_XL_2(from_pretrained=None, **kwargs):
    model = STDiT_freq(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model
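The STDiT blocks above all condition on the diffusion timestep through the same adaLN-style mechanism: t_block maps the timestep embedding to six modulation vectors, which are combined with the learned scale_shift_table, split with chunk(6, dim=1), and applied via t2i_modulate plus a gated residual. The standalone sketch below reproduces that pattern in plain PyTorch. The exact definition of t2i_modulate lives in opensora.models.layers.blocks and is not shown in this commit, so the x * (1 + scale) + shift form used here is an assumption carried over from PixArt-style models, and all shapes are made up for illustration.

# Standalone sketch (not part of the repository) of the adaLN-style timestep conditioning.
# Assumption: t2i_modulate(x, shift, scale) behaves like x * (1 + scale) + shift, as in PixArt.
import torch
import torch.nn as nn

B, N, C = 2, 16, 1152                                     # hypothetical batch, tokens, hidden size
t = torch.randn(B, C)                                     # timestep embedding (as from TimestepEmbedder)
t_block = nn.Sequential(nn.SiLU(), nn.Linear(C, 6 * C))   # mirrors self.t_block
scale_shift_table = nn.Parameter(torch.randn(6, C) / C**0.5)

t0 = t_block(t)                                           # [B, 6*C]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    scale_shift_table[None] + t0.reshape(B, 6, -1)
).chunk(6, dim=1)                                         # each chunk is [B, 1, C]

x = torch.randn(B, N, C)                                  # token sequence entering the block
x_mod = x * (1 + scale_msa) + shift_msa                   # assumed t2i_modulate
attn_out = x_mod                                          # stand-in for self.attn(...) on the modulated tokens
x = x + gate_msa * attn_out                               # gated residual, as in STDiTBlock.forward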
utils_data/opensora/models/stdit/stdit_mmdit.py (new file, mode 0 → 100644, commit 1f5da520)
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp

from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.models.layers.blocks import (
    Attention, Attention_QKNorm_RoPE, MaskedSelfAttention, CaptionEmbedder, MultiHeadCrossAttention,
    MaskedMultiHeadCrossAttention, PatchEmbed3D, SeqParallelAttention, SeqParallelMultiHeadCrossAttention,
    T2IFinalLayer, TimestepEmbedder, approx_gelu, get_1d_sincos_pos_embed, get_2d_sincos_pos_embed,
    get_layernorm, t2i_modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint

# import ipdb
from opensora.models.layers.timm_uvit import trunc_normal_


class STDiTBlock(nn.Module):
    def __init__(
        self, hidden_size, num_heads, d_s=None, d_t=None, mlp_ratio=4.0, drop_path=0.0,
        enable_flashattn=False, enable_layernorm_kernel=False, enable_sequence_parallelism=False, qk_norm=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.enable_flashattn = enable_flashattn
        self._enable_sequence_parallelism = enable_sequence_parallelism

        if enable_sequence_parallelism:
            self.attn_cls = SeqParallelAttention
            self.mha_cls = SeqParallelMultiHeadCrossAttention
        else:
            # here
            self.self_masked_attn = MaskedSelfAttention
            self.attn_cls = Attention_QKNorm_RoPE
            self.mha_cls = MaskedMultiHeadCrossAttention

        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.norm1_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.norm2_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = self.self_masked_attn(
            hidden_size, num_heads=num_heads, qkv_bias=True, enable_flashattn=enable_flashattn, qk_norm=qk_norm,
        )
        self.cross_attn = self.mha_cls(hidden_size, num_heads)
        self.norm3 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.norm3_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
        self.mlp_y = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
        self.scale_shift_table_y = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
        self.scale_shift_table_temp = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)
        self.scale_shift_table_y_temp = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)

        # temporal attention
        self.d_s = d_s
        self.d_t = d_t
        if self._enable_sequence_parallelism:
            sp_size = dist.get_world_size(get_sequence_parallel_group())
            # make sure d_t is divisible by sp_size
            assert d_t % sp_size == 0
            self.d_t = d_t // sp_size

        self.attn_temp = self.attn_cls(
            hidden_size, num_heads=num_heads, qkv_bias=True, enable_flashattn=self.enable_flashattn, qk_norm=qk_norm,
        )

    def forward(self, x, y, t, t_y, t_tmep, t_y_tmep, mask=None, tpe=None):
        B, N, C = x.shape
        L = y.shape[2]
        # y: B T L C, mask: B T L
        x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
        x_mask = torch.ones(x.shape[:3], device=x.device, dtype=x.dtype)  # B T S

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = (
            self.scale_shift_table_y[None] + t_y.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        shift_msa_temp, scale_msa_temp, gate_msa_temp = (
            self.scale_shift_table_temp[None] + t_tmep.reshape(B, 3, -1)
        ).chunk(3, dim=1)
        shift_msa_y_temp, scale_msa_y_temp, gate_msa_y_temp = (
            self.scale_shift_table_y_temp[None] + t_y_tmep.reshape(B, 3, -1)
        ).chunk(3, dim=1)

        x = rearrange(x, "B T S C -> B (T S) C")
        y = rearrange(y, "B T L C -> B (T L) C")
        x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
        y_m = t2i_modulate(self.norm1_y(y), shift_msa_y, scale_msa_y)
        x_m = rearrange(x_m, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
        y_m = rearrange(y_m, "B (T L) C -> B T L C", T=self.d_t, L=L)
        xy_m = torch.cat([x_m, y_m], dim=2)
        xy_mask = torch.cat([x_mask, mask], dim=2)
        xy_mask = rearrange(xy_mask, "B T N -> (B T) N")

        # spatial branch
        xy_s = rearrange(xy_m, "B T N C -> (B T) N C")
        xy_s = self.attn(xy_s, xy_mask)
        xy_s = rearrange(xy_s, "(B T) N C -> B T N C", B=B, T=self.d_t)
        x_s = xy_s[:, :, : self.d_s, :]
        y_s = xy_s[:, :, self.d_s :, :]
        x_s = rearrange(x_s, "B T S C -> B (T S) C")
        y_s = rearrange(y_s, "B T L C -> B (T L) C")
        x = x + self.drop_path(gate_msa * x_s)
        y = y + self.drop_path(gate_msa_y * y_s)

        x_t = t2i_modulate(self.norm2(x), shift_msa_temp, scale_msa_temp)
        y_t = t2i_modulate(self.norm2_y(y), shift_msa_y_temp, scale_msa_y_temp)
        x_t = rearrange(x_t, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
        y_t = rearrange(y_t, "B (T L) C -> B T L C", T=self.d_t, L=L)
        xy_t = torch.cat([x_t, y_t], dim=2)

        # temporal branch
        xy_t = rearrange(xy_t, "B T N C -> (B N) T C")
        if tpe is not None:
            xy_t = xy_t + tpe
        xy_t = self.attn_temp(xy_t)
        xy_t = rearrange(xy_t, "(B N) T C -> B T N C", B=B, N=self.d_s + L)
        x_t = xy_t[:, :, : self.d_s, :]
        y_t = xy_t[:, :, self.d_s :, :]
        x_t = rearrange(x_t, "B T S C -> B (T S) C")
        y_t = rearrange(y_t, "B T L C -> B (T L) C")
        x = x + self.drop_path(gate_msa_temp * x_t)
        y = y + self.drop_path(gate_msa_y_temp * y_t)

        x = rearrange(x, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
        y = rearrange(y, "B (T L) C -> (B T) L C", T=self.d_t, L=L)
        mask = rearrange(mask, "B T L -> (B T) L")

        # cross attn
        x = x + self.cross_attn(x, y, mask)
        x = rearrange(x, "(B T) S C -> B (T S) C", B=B, T=self.d_t)
        y = rearrange(y, "(B T) L C -> B (T L) C", B=B, T=self.d_t)

        # mlp
        x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm3(x), shift_mlp, scale_mlp)))
        y = y + self.drop_path(gate_mlp_y * self.mlp_y(t2i_modulate(self.norm3_y(y), shift_mlp_y, scale_mlp_y)))
        y = rearrange(y, "B (T L) C -> B T L C", T=self.d_t, L=L)

        return x, y


@MODELS.register_module()
class STDiT_MMDiT(nn.Module):
    def __init__(
        self, input_size=(1, 32, 32), in_channels=4, patch_size=(1, 2, 2), hidden_size=1152, depth=28,
        num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path=0.0,
        no_temporal_pos_emb=False, caption_channels=4096, model_max_length=120, dtype=torch.float32,
        space_scale=1.0, time_scale=1.0, freeze=None, enable_flashattn=False,
        enable_layernorm_kernel=False, enable_sequence_parallelism=False, qk_norm=False,
    ):
        super().__init__()
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.input_size = input_size
        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
        self.num_patches = num_patches
        self.num_temporal = input_size[0] // patch_size[0]
        self.num_spatial = num_patches // self.num_temporal
        self.num_heads = num_heads
        self.dtype = dtype
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.enable_flashattn = enable_flashattn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.space_scale = space_scale
        self.time_scale = time_scale

        self.register_buffer("pos_embed", self.get_spatial_pos_embed())
        self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())

        self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        self.t_block_y = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        self.t_block_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True))
        self.t_block_y_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True))
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
            act_layer=approx_gelu, token_num=model_max_length,
        )

        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
        self.blocks = nn.ModuleList(
            [
                STDiTBlock(
                    self.hidden_size, self.num_heads, mlp_ratio=self.mlp_ratio, drop_path=drop_path[i],
                    enable_flashattn=self.enable_flashattn, enable_layernorm_kernel=self.enable_layernorm_kernel,
                    enable_sequence_parallelism=enable_sequence_parallelism,
                    d_t=self.num_temporal, d_s=self.num_spatial, qk_norm=qk_norm,
                )
                for i in range(self.depth)
            ]
        )
        self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)

        # init model
        self.initialize_weights()
        self.initialize_temporal()
        if freeze is not None:
            assert freeze in ["not_temporal", "text", "not_attn"]
            if freeze == "not_temporal":
                self.freeze_not_temporal()
            elif freeze == "text":
                self.freeze_text()
            elif freeze == "not_attn":
                self.freeze_not_attn()

        # sequence parallel related configs
        self.enable_sequence_parallelism = enable_sequence_parallelism
        if enable_sequence_parallelism:
            self.sp_rank = dist.get_rank(get_sequence_parallel_group())
        else:
            self.sp_rank = None

    def forward(self, x, timestep, y, mask=None):
        """
        Forward pass of STDiT.

        Args:
            x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
            timestep (torch.Tensor): diffusion time steps; of shape [B]
            y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
            mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]

        Returns:
            x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # [B, (THW), C]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")

        # shard over the sequence dim if sp is enabled
        if self.enable_sequence_parallelism:
            x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")

        t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        t0 = self.t_block(t)  # [B, C]
        t_y = self.t_block_y(t)
        t0_tmep = self.t_block_temp(t)  # [B, C]
        t_y_tmep = self.t_block_y_temp(t)
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            y = y.repeat(1, self.num_temporal, 1, 1)  # B T L C
            mask = mask.unsqueeze(1).repeat(1, self.num_temporal, 1)  # B T L
            # mask = mask.squeeze(1).squeeze(1)
            # y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            # y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # blocks
        for i, block in enumerate(self.blocks):
            if i == 0:
                if self.enable_sequence_parallelism:
                    tpe = torch.chunk(
                        self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
                    )[self.sp_rank].contiguous()
                else:
                    tpe = self.pos_embed_temporal
            else:
                tpe = None
            x, y = auto_grad_checkpoint(block, x, y, t0, t_y, t0_tmep, t_y_tmep, mask, tpe)

        if self.enable_sequence_parallelism:
            x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
        # x.shape: [B, N, C]

        # final process
        x = self.final_layer(x, t)  # [B, N, C=T_p * H_p * W_p * C_out]
        x = self.unpatchify(x)  # [B, C_out, T, H, W]

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def unpatchify(self, x):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x, "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t, N_h=N_h, N_w=N_w, T_p=T_p, H_p=H_p, W_p=W_p, C_out=self.out_channels,
        )
        return x

    def unpatchify_old(self, x):
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size
        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def get_spatial_pos_embed(self, grid_size=None):
        if grid_size is None:
            grid_size = self.input_size[1:]
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size, (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
            scale=self.space_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size, self.input_size[0] // self.patch_size[0], scale=self.time_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def freeze_not_temporal(self):
        for n, p in self.named_parameters():
            if "attn_temp" not in n:
                p.requires_grad = False

    def freeze_text(self):
        for n, p in self.named_parameters():
            if "cross_attn" in n:
                p.requires_grad = False

    def freeze_not_attn(self):
        for n, p in self.named_parameters():
            if "attn" not in n:
                p.requires_grad = False
            if "cross_attn" in n or "attn_temp" in n:
                p.requires_grad = False

    def initialize_temporal(self):
        for block in self.blocks:
            nn.init.constant_(block.attn_temp.proj.weight, 0)
            nn.init.constant_(block.attn_temp.proj.bias, 0)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)
        nn.init.normal_(self.t_block_y[1].weight, std=0.02)
        nn.init.normal_(self.t_block_temp[1].weight, std=0.02)
        nn.init.normal_(self.t_block_y_temp[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)


@MODELS.register_module("STDiT_MMDiT_XL/2")
def STDiT_MMDiT_XL_2(from_pretrained=None, **kwargs):
    model = STDiT_MMDiT(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model
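In STDiT_MMDiT the spatial attention is MMDiT-style: for each frame the video tokens and the per-frame-repeated caption tokens are concatenated along the token axis, attended jointly under a padding mask, and then split back into the two streams. The sketch below illustrates only that data flow, with torch.nn.MultiheadAttention standing in for the repository's MaskedSelfAttention; the layer choice and all shapes are assumptions for illustration.

# Illustrative sketch (not part of the repository) of the joint video/caption spatial attention.
import torch
import torch.nn as nn
from einops import rearrange

B, T, S, L, C, heads = 2, 4, 64, 16, 256, 8    # hypothetical sizes
x = torch.randn(B, T, S, C)                    # video tokens per frame
y = torch.randn(B, T, L, C)                    # caption tokens per frame
y_mask = torch.ones(B, T, L)                   # 1 = valid caption token, 0 = padding

xy = rearrange(torch.cat([x, y], dim=2), "B T N C -> (B T) N C")    # concat along the token axis
pad_mask = torch.cat([torch.ones(B, T, S), y_mask], dim=2)          # video tokens are always valid
pad_mask = rearrange(pad_mask, "B T N -> (B T) N") == 0             # True = ignore in attention

attn = nn.MultiheadAttention(C, heads, batch_first=True)            # stand-in for MaskedSelfAttention
out, _ = attn(xy, xy, xy, key_padding_mask=pad_mask)
out = rearrange(out, "(B T) N C -> B T N C", B=B, T=T)
x_out, y_out = out[:, :, :S], out[:, :, S:]                         # split back into the two streams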
utils_data/opensora/models/stdit/stdit_mmdit_qk.py (new file, mode 0 → 100644, commit 1f5da520)
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp

from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.models.layers.blocks import (
    Attention, Attention_QKNorm_RoPE, MaskedSelfAttention, CaptionEmbedder, MultiHeadCrossAttention,
    MaskedMultiHeadCrossAttention, PatchEmbed3D, SeqParallelAttention, SeqParallelMultiHeadCrossAttention,
    T2IFinalLayer, TimestepEmbedder, approx_gelu, get_1d_sincos_pos_embed, get_2d_sincos_pos_embed,
    get_layernorm, t2i_modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint

import ipdb
from opensora.models.layers.timm_uvit import trunc_normal_


class STDiTBlock(nn.Module):
    def __init__(
        self, hidden_size, num_heads, d_s=None, d_t=None, mlp_ratio=4.0, drop_path=0.0,
        enable_flashattn=False, enable_layernorm_kernel=False, enable_sequence_parallelism=False, qk_norm=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.enable_flashattn = enable_flashattn
        self._enable_sequence_parallelism = enable_sequence_parallelism

        if enable_sequence_parallelism:
            self.attn_cls = SeqParallelAttention
            self.mha_cls = SeqParallelMultiHeadCrossAttention
        else:
            # here
            self.self_masked_attn = MaskedSelfAttention
            self.attn_cls = Attention_QKNorm_RoPE
            self.mha_cls = MaskedMultiHeadCrossAttention

        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.norm1_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.norm2_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = self.self_masked_attn(
            hidden_size, num_heads=num_heads, qkv_bias=True, enable_flashattn=enable_flashattn, qk_norm=qk_norm,
        )
        self.cross_attn = self.mha_cls(hidden_size, num_heads)
        self.norm3 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.norm3_y = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
        self.mlp_y = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
        self.scale_shift_table_y = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
        self.scale_shift_table_temp = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)
        self.scale_shift_table_y_temp = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)

        # temporal attention
        self.d_s = d_s
        self.d_t = d_t
        if self._enable_sequence_parallelism:
            sp_size = dist.get_world_size(get_sequence_parallel_group())
            # make sure d_t is divisible by sp_size
            assert d_t % sp_size == 0
            self.d_t = d_t // sp_size

        self.attn_temp = self.attn_cls(
            hidden_size, num_heads=num_heads, qkv_bias=True, enable_flashattn=self.enable_flashattn, qk_norm=qk_norm,
        )

    def forward(self, x, y, t, t_y, t_tmep, t_y_tmep, mask=None, tpe=None):
        B, N, C = x.shape
        L = y.shape[2]
        # y: B T L C, mask: B T L
        x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
        x_mask = torch.ones(x.shape[:3], device=x.device, dtype=x.dtype)  # B T S

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = (
            self.scale_shift_table_y[None] + t_y.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        shift_msa_temp, scale_msa_temp, gate_msa_temp = (
            self.scale_shift_table_temp[None] + t_tmep.reshape(B, 3, -1)
        ).chunk(3, dim=1)
        shift_msa_y_temp, scale_msa_y_temp, gate_msa_y_temp = (
            self.scale_shift_table_y_temp[None] + t_y_tmep.reshape(B, 3, -1)
        ).chunk(3, dim=1)

        x = rearrange(x, "B T S C -> B (T S) C")
        y = rearrange(y, "B T L C -> B (T L) C")
        x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
        y_m = t2i_modulate(self.norm1_y(y), shift_msa_y, scale_msa_y)
        x_m = rearrange(x_m, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
        y_m = rearrange(y_m, "B (T L) C -> B T L C", T=self.d_t, L=L)
        xy_m = torch.cat([x_m, y_m], dim=2)
        xy_mask = torch.cat([x_mask, mask], dim=2)
        xy_mask = rearrange(xy_mask, "B T N -> (B T) N")

        # spatial branch
        xy_s = rearrange(xy_m, "B T N C -> (B T) N C")
        xy_s = self.attn(xy_s, xy_mask)
        xy_s = rearrange(xy_s, "(B T) N C -> B T N C", B=B, T=self.d_t)
        x_s = xy_s[:, :, : self.d_s, :]
        y_s = xy_s[:, :, self.d_s :, :]
        x_s = rearrange(x_s, "B T S C -> B (T S) C")
        y_s = rearrange(y_s, "B T L C -> B (T L) C")
        x = x + self.drop_path(gate_msa * x_s)
        y = y + self.drop_path(gate_msa_y * y_s)

        x_t = t2i_modulate(self.norm2(x), shift_msa_temp, scale_msa_temp)
        y_t = t2i_modulate(self.norm2_y(y), shift_msa_y_temp, scale_msa_y_temp)
        x_t = rearrange(x_t, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
        y_t = rearrange(y_t, "B (T L) C -> B T L C", T=self.d_t, L=L)
        xy_t = torch.cat([x_t, y_t], dim=2)

        # temporal branch
        xy_t = rearrange(xy_t, "B T N C -> (B N) T C")
        if tpe is not None:
            xy_t = xy_t + tpe
        xy_t = self.attn_temp(xy_t)
        xy_t = rearrange(xy_t, "(B N) T C -> B T N C", B=B, N=self.d_s + L)
        x_t = xy_t[:, :, : self.d_s, :]
        y_t = xy_t[:, :, self.d_s :, :]
        x_t = rearrange(x_t, "B T S C -> B (T S) C")
        y_t = rearrange(y_t, "B T L C -> B (T L) C")
        x = x + self.drop_path(gate_msa_temp * x_t)
        y = y + self.drop_path(gate_msa_y_temp * y_t)

        x = rearrange(x, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
        y = rearrange(y, "B (T L) C -> (B T) L C", T=self.d_t, L=L)
        mask = rearrange(mask, "B T L -> (B T) L")

        # cross attn
        x = x + self.cross_attn(x, y, mask)
        x = rearrange(x, "(B T) S C -> B (T S) C", B=B, T=self.d_t)
        y = rearrange(y, "(B T) L C -> B (T L) C", B=B, T=self.d_t)

        # mlp
        x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm3(x), shift_mlp, scale_mlp)))
        y = y + self.drop_path(gate_mlp_y * self.mlp_y(t2i_modulate(self.norm3_y(y), shift_mlp_y, scale_mlp_y)))
        y = rearrange(y, "B (T L) C -> B T L C", T=self.d_t, L=L)

        return x, y


@MODELS.register_module()
class STDiT_MMDiTQK(nn.Module):
    def __init__(
        self, input_size=(1, 32, 32), in_channels=4, patch_size=(1, 2, 2), hidden_size=1152, depth=28,
        num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path=0.0,
        no_temporal_pos_emb=False, caption_channels=4096, model_max_length=120, dtype=torch.float32,
        space_scale=1.0, time_scale=1.0, freeze=None, enable_flashattn=False,
        enable_layernorm_kernel=False, enable_sequence_parallelism=False, qk_norm=True,
    ):
        super().__init__()
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.input_size = input_size
        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
        self.num_patches = num_patches
        self.num_temporal = input_size[0] // patch_size[0]
        self.num_spatial = num_patches // self.num_temporal
        self.num_heads = num_heads
        self.dtype = dtype
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.enable_flashattn = enable_flashattn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.space_scale = space_scale
        self.time_scale = time_scale

        self.register_buffer("pos_embed", self.get_spatial_pos_embed())
        self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())

        self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        self.t_block_y = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        self.t_block_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True))
        self.t_block_y_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True))
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
            act_layer=approx_gelu, token_num=model_max_length,
        )

        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
        self.blocks = nn.ModuleList(
            [
                STDiTBlock(
                    self.hidden_size, self.num_heads, mlp_ratio=self.mlp_ratio, drop_path=drop_path[i],
                    enable_flashattn=self.enable_flashattn, enable_layernorm_kernel=self.enable_layernorm_kernel,
                    enable_sequence_parallelism=enable_sequence_parallelism,
                    d_t=self.num_temporal, d_s=self.num_spatial, qk_norm=qk_norm,
                )
                for i in range(self.depth)
            ]
        )
        self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)

        # init model
        self.initialize_weights()
        self.initialize_temporal()
        if freeze is not None:
            assert freeze in ["not_temporal", "text"]
            if freeze == "not_temporal":
                self.freeze_not_temporal()
            elif freeze == "text":
                self.freeze_text()

        # sequence parallel related configs
        self.enable_sequence_parallelism = enable_sequence_parallelism
        if enable_sequence_parallelism:
            self.sp_rank = dist.get_rank(get_sequence_parallel_group())
        else:
            self.sp_rank = None

    def forward(self, x, timestep, y, mask=None):
        """
        Forward pass of STDiT.

        Args:
            x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
            timestep (torch.Tensor): diffusion time steps; of shape [B]
            y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
            mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]

        Returns:
            x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # [B, (THW), C]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")

        # shard over the sequence dim if sp is enabled
        if self.enable_sequence_parallelism:
            x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")

        t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        t0 = self.t_block(t)  # [B, C]
        t_y = self.t_block_y(t)
        t0_tmep = self.t_block_temp(t)  # [B, C]
        t_y_tmep = self.t_block_y_temp(t)
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            y = y.repeat(1, self.num_temporal, 1, 1)  # B T L C
            mask = mask.unsqueeze(1).repeat(1, self.num_temporal, 1)  # B T L
            # mask = mask.squeeze(1).squeeze(1)
            # y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            # y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # blocks
        for i, block in enumerate(self.blocks):
            if i == 0:
                if self.enable_sequence_parallelism:
                    tpe = torch.chunk(
                        self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
                    )[self.sp_rank].contiguous()
                else:
                    tpe = self.pos_embed_temporal
            else:
                tpe = None
            x, y = auto_grad_checkpoint(block, x, y, t0, t_y, t0_tmep, t_y_tmep, mask, tpe)

        if self.enable_sequence_parallelism:
            x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
        # x.shape: [B, N, C]

        # final process
        x = self.final_layer(x, t)  # [B, N, C=T_p * H_p * W_p * C_out]
        x = self.unpatchify(x)  # [B, C_out, T, H, W]

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def unpatchify(self, x):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x, "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t, N_h=N_h, N_w=N_w, T_p=T_p, H_p=H_p, W_p=W_p, C_out=self.out_channels,
        )
        return x

    def unpatchify_old(self, x):
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size
        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def get_spatial_pos_embed(self, grid_size=None):
        if grid_size is None:
            grid_size = self.input_size[1:]
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size, (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
            scale=self.space_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size, self.input_size[0] // self.patch_size[0], scale=self.time_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def freeze_not_temporal(self):
        for n, p in self.named_parameters():
            if "attn_temp" not in n:
                p.requires_grad = False

    def freeze_text(self):
        for n, p in self.named_parameters():
            if "cross_attn" in n:
                p.requires_grad = False

    def initialize_temporal(self):
        for block in self.blocks:
            nn.init.constant_(block.attn_temp.proj.weight, 0)
            nn.init.constant_(block.attn_temp.proj.bias, 0)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)
        nn.init.normal_(self.t_block_y[1].weight, std=0.02)
        nn.init.normal_(self.t_block_temp[1].weight, std=0.02)
        nn.init.normal_(self.t_block_y_temp[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)


@MODELS.register_module("STDiT_MMDiTQK_XL/2")
def STDiT_MMDiTQK_XL_2(from_pretrained=None, **kwargs):
    model = STDiT_MMDiTQK(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model
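Both MMDiT variants are exposed through registered XL/2 builders. A minimal instantiation sketch follows; it assumes the opensora package and its custom attention layers import and construct cleanly, and the input_size value is a hypothetical latent shape rather than a repository default.

# Minimal instantiation sketch (not part of the repository).
model = STDiT_MMDiTQK_XL_2(
    input_size=(16, 32, 32),   # hypothetical latent (T, H, W)
    in_channels=4,
    qk_norm=True,              # this variant defaults to QK-normalized attention
    from_pretrained=None,      # or a checkpoint path, forwarded to load_checkpoint()
)
# Expected inputs, per the forward() docstring:
#   x: [B, 4, 16, 32, 32] video latent, timestep: [B],
#   y: [B, 1, N_token, 4096] caption embedding, mask: [B, N_token]
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")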