import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.models.layers.blocks import (
Attention,
Attention_QKNorm_RoPE,
CaptionEmbedder,
MultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
from rotary_embedding_torch import RotaryEmbedding
class STDiTBlock(nn.Module):
def __init__(
self,
hidden_size,
num_heads,
d_s=None,
d_t=None,
mlp_ratio=4.0,
drop_path=0.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
rope=None,
qk_norm=False,
):
super().__init__()
self.hidden_size = hidden_size
self.enable_flashattn = enable_flashattn
self._enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.attn_cls = SeqParallelAttention
self.mha_cls = SeqParallelMultiHeadCrossAttention
        else:  # default path: attention with optional QK-norm and RoPE
self.attn_cls = Attention_QKNorm_RoPE
self.mha_cls = MultiHeadCrossAttention
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.attn = self.attn_cls(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=enable_flashattn,
qk_norm=qk_norm,
)
self.cross_attn = self.mha_cls(hidden_size, num_heads)
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
self.scale_shift_table_temporal = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5) # new
# temporal attention
self.d_s = d_s
self.d_t = d_t
if self._enable_sequence_parallelism:
sp_size = dist.get_world_size(get_sequence_parallel_group())
# make sure d_t is divisible by sp_size
assert d_t % sp_size == 0
self.d_t = d_t // sp_size
self.norm_temp = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) # new
self.attn_temp = self.attn_cls(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=self.enable_flashattn,
rope=rope,
qk_norm=qk_norm,
)
def forward(self, x, y, t, t_temp, mask=None, tpe=None):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + t.reshape(B, 6, -1)
).chunk(6, dim=1)
shift_tmp, scale_tmp, gate_tmp = (
self.scale_shift_table_temporal[None] + t_temp.reshape(B, 3, -1)
).chunk(3, dim=1)
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
# spatial branch
x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
x_s = self.attn(x_s)
x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
x = x + self.drop_path(gate_msa * x_s)
# modulate
x_m = t2i_modulate(self.norm_temp(x), shift_tmp, scale_tmp)
# temporal branch
x_t = rearrange(x_m, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
if tpe is not None:
x_t = x_t + tpe
x_t = self.attn_temp(x_t)
x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s)
x = x + self.drop_path(gate_tmp * x_t)
# cross attn
x = x + self.cross_attn(x, y, mask)
# mlp
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
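# Hedged sketch (illustration only, not part of the commit): how the learned scale_shift_table
# above combines with the [B, 6*C] timestep embedding to yield the six [B, 1, C] modulation
# tensors, mirroring the chunking done at the top of STDiTBlock.forward.
def _example_adaln_chunks():
    B, C = 2, 8
    table = torch.randn(6, C) / C**0.5  # stands in for self.scale_shift_table
    t = torch.randn(B, 6 * C)  # stands in for the t argument of forward()
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
        table[None] + t.reshape(B, 6, -1)
    ).chunk(6, dim=1)
    return shift_msa.shape  # torch.Size([2, 1, 8])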
@MODELS.register_module()
class STDiT_QKNorm_RoPE(nn.Module):
def __init__(
self,
input_size=(1, 32, 32),
in_channels=4,
patch_size=(1, 2, 2),
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
pred_sigma=True,
drop_path=0.0,
no_temporal_pos_emb=False,
caption_channels=4096,
model_max_length=120,
dtype=torch.float32,
space_scale=1.0,
time_scale=1.0,
freeze=None,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
qk_norm=False,
rope=False,
):
super().__init__()
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.input_size = input_size
num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
self.num_patches = num_patches
self.num_temporal = input_size[0] // patch_size[0]
self.num_spatial = num_patches // self.num_temporal
self.num_heads = num_heads
self.dtype = dtype
self.no_temporal_pos_emb = no_temporal_pos_emb
self.depth = depth
self.mlp_ratio = mlp_ratio
self.enable_flashattn = enable_flashattn
self.enable_layernorm_kernel = enable_layernorm_kernel
self.space_scale = space_scale
self.time_scale = time_scale
self.register_buffer("pos_embed", self.get_spatial_pos_embed())
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
self.t_block_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True))
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
uncond_prob=class_dropout_prob,
act_layer=approx_gelu,
token_num=model_max_length,
)
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
if rope:
RoPE = RotaryEmbedding(dim=self.hidden_size // self.num_heads)
self.rope = RoPE.rotate_queries_or_keys
else:
self.rope = None
self.blocks = nn.ModuleList(
[
STDiTBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=self.mlp_ratio,
drop_path=drop_path[i],
enable_flashattn=self.enable_flashattn,
enable_layernorm_kernel=self.enable_layernorm_kernel,
enable_sequence_parallelism=enable_sequence_parallelism,
d_t=self.num_temporal,
d_s=self.num_spatial,
rope=self.rope,
qk_norm=qk_norm,
)
for i in range(self.depth)
]
)
self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
# init model
self.initialize_weights()
self.initialize_temporal()
if freeze is not None:
assert freeze in ["not_temporal", "text"]
if freeze == "not_temporal":
self.freeze_not_temporal()
elif freeze == "text":
self.freeze_text()
# sequence parallel related configs
self.enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.sp_rank = dist.get_rank(get_sequence_parallel_group())
else:
self.sp_rank = None
def forward(self, x, timestep, y, mask=None):
"""
Forward pass of STDiT.
Args:
x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
timestep (torch.Tensor): diffusion time steps; of shape [B]
y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
Returns:
            x (torch.Tensor): output latent representation; of shape [B, C_out, T, H, W]
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
if self.enable_sequence_parallelism:
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
        t0 = self.t_block(t)  # [B, 6 * C]
        t0_temp = self.t_block_temp(t)  # [B, 3 * C]
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# blocks
for i, block in enumerate(self.blocks):
if i == 0:
if self.enable_sequence_parallelism:
tpe = torch.chunk(
self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
)[self.sp_rank].contiguous()
else:
tpe = self.pos_embed_temporal
else:
tpe = None
x = auto_grad_checkpoint(block, x, y, t0, t0_temp, y_lens, tpe)
if self.enable_sequence_parallelism:
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
# x.shape: [B, N, C]
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def unpatchify(self, x):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
def unpatchify_old(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, grid_size=None):
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def freeze_not_temporal(self):
for n, p in self.named_parameters():
if "attn_temp" not in n:
p.requires_grad = False
def freeze_text(self):
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False
def initialize_temporal(self):
for block in self.blocks:
nn.init.constant_(block.attn_temp.proj.weight, 0)
nn.init.constant_(block.attn_temp.proj.bias, 0)
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
nn.init.normal_(self.t_block_temp[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
        # Zero-out the output projection of cross-attention in each block (as in PixArt):
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@MODELS.register_module("STDiT_QKNorm_RoPE_XL/2")
def STDiT_QKNorm_RoPE_XL_2(from_pretrained=None, **kwargs):
model = STDiT_QKNorm_RoPE(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
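# Hedged usage sketch (assumption, not part of the commit): instantiate the XL/2 variant for a
# 16x32x32 latent grid and run a dummy forward pass; tensor shapes follow the forward() docstring.
# The qk_norm/rope flags and all sizes here are illustrative.
def _example_stdit_xl2_forward():
    model = STDiT_QKNorm_RoPE_XL_2(input_size=(16, 32, 32), in_channels=4, qk_norm=True, rope=True)
    x = torch.randn(2, 4, 16, 32, 32)  # latent video, [B, C, T, H, W]
    timestep = torch.randint(0, 1000, (2,))  # diffusion steps, [B]
    y = torch.randn(2, 1, 120, 4096)  # caption embeddings, [B, 1, N_token, C_caption]
    mask = torch.ones(2, 120, dtype=torch.long)  # all caption tokens valid
    out = model(x, timestep, y, mask=mask)  # [B, 2 * C, T, H, W] since pred_sigma=True
    return out.shape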
from .classes import ClassEncoder
from .clip import ClipEncoder
from .t5 import T5Encoder
import torch
from opensora.registry import MODELS
@MODELS.register_module("classes")
class ClassEncoder:
def __init__(self, num_classes, model_max_length=None, device="cuda", dtype=torch.float):
self.num_classes = num_classes
self.y_embedder = None
self.model_max_length = model_max_length
self.output_dim = None
self.device = device
def encode(self, text):
return dict(y=torch.tensor([int(t) for t in text]).to(self.device))
def null(self, n):
return torch.tensor([self.num_classes] * n).to(self.device)
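# Hedged usage sketch: for class-conditional training, "prompts" are integer class ids given as
# strings; null() returns the extra "unconditional" class id used for classifier-free guidance.
def _example_class_encoder():
    enc = ClassEncoder(num_classes=1000, device="cpu")
    y = enc.encode(["3", "7"])["y"]  # tensor([3, 7])
    null_y = enc.null(2)  # tensor([1000, 1000])
    return y, null_y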
# Copyright 2024 Vchitect/Latte
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from Latte
#
# This file is adapted from the Latte project.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Latte: https://github.com/Vchitect/Latte
# DiT: https://github.com/facebookresearch/DiT/tree/main
# --------------------------------------------------------
import torch
import torch.nn as nn
import transformers
from transformers import CLIPTextModel, CLIPTokenizer
from opensora.registry import MODELS
transformers.logging.set_verbosity_error()
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, path="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(path)
self.transformer = CLIPTextModel.from_pretrained(path)
self.device = device
self.max_length = max_length
self._freeze()
def _freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
pooled_z = outputs.pooler_output
return z, pooled_z
def encode(self, text):
return self(text)
@MODELS.register_module("clip")
class ClipEncoder:
"""
Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
"""
def __init__(
self,
from_pretrained,
model_max_length=77,
device="cuda",
dtype=torch.float,
):
super().__init__()
        assert from_pretrained is not None, "Please specify the path to the CLIP model"
        self.text_encoder = FrozenCLIPEmbedder(path=from_pretrained, device=device, max_length=model_max_length).to(device, dtype)
self.y_embedder = None
self.model_max_length = model_max_length
self.output_dim = self.text_encoder.transformer.config.hidden_size
def encode(self, text):
_, pooled_embeddings = self.text_encoder.encode(text)
y = pooled_embeddings.unsqueeze(1).unsqueeze(1)
return dict(y=y)
def null(self, n):
null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
return null_y
def to(self, dtype):
self.text_encoder = self.text_encoder.to(dtype)
return self
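# Hedged usage sketch (assumption): "openai/clip-vit-large-patch14" is an illustrative checkpoint
# rather than one pinned by this commit, and a CUDA device is assumed. encode() returns pooled
# text embeddings shaped [B, 1, 1, hidden_size] (768 for ViT-L/14).
def _example_clip_encode():
    enc = ClipEncoder(from_pretrained="openai/clip-vit-large-patch14", model_max_length=77, device="cuda")
    out = enc.encode(["a cat playing piano", "a dog surfing"])
    return out["y"].shape  # torch.Size([2, 1, 1, 768])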
# Adapted from PixArt
#
# Copyright (C) 2023 PixArt-alpha/PixArt-alpha
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# T5: https://github.com/google-research/text-to-text-transfer-transformer
# --------------------------------------------------------
import html
import re
import ftfy
import torch
from transformers import AutoTokenizer, T5EncoderModel
from opensora.registry import MODELS
class T5Embedder:
def __init__(
self,
device,
from_pretrained=None,
*,
cache_dir=None,
hf_token=None,
use_text_preprocessing=True,
t5_model_kwargs=None,
torch_dtype=None,
use_offload_folder=None,
model_max_length=120,
local_files_only=False,
):
self.device = torch.device(device)
self.torch_dtype = torch_dtype or torch.bfloat16
self.cache_dir = cache_dir
if t5_model_kwargs is None:
t5_model_kwargs = {
"low_cpu_mem_usage": True,
"torch_dtype": self.torch_dtype,
}
if use_offload_folder is not None:
t5_model_kwargs["offload_folder"] = use_offload_folder
t5_model_kwargs["device_map"] = {
"shared": self.device,
"encoder.embed_tokens": self.device,
"encoder.block.0": self.device,
"encoder.block.1": self.device,
"encoder.block.2": self.device,
"encoder.block.3": self.device,
"encoder.block.4": self.device,
"encoder.block.5": self.device,
"encoder.block.6": self.device,
"encoder.block.7": self.device,
"encoder.block.8": self.device,
"encoder.block.9": self.device,
"encoder.block.10": self.device,
"encoder.block.11": self.device,
"encoder.block.12": "disk",
"encoder.block.13": "disk",
"encoder.block.14": "disk",
"encoder.block.15": "disk",
"encoder.block.16": "disk",
"encoder.block.17": "disk",
"encoder.block.18": "disk",
"encoder.block.19": "disk",
"encoder.block.20": "disk",
"encoder.block.21": "disk",
"encoder.block.22": "disk",
"encoder.block.23": "disk",
"encoder.final_layer_norm": "disk",
"encoder.dropout": "disk",
}
else:
t5_model_kwargs["device_map"] = {
"shared": self.device,
"encoder": self.device,
}
self.use_text_preprocessing = use_text_preprocessing
self.hf_token = hf_token
self.tokenizer = AutoTokenizer.from_pretrained(
from_pretrained,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
self.model = T5EncoderModel.from_pretrained(
from_pretrained,
cache_dir=cache_dir,
local_files_only=local_files_only,
**t5_model_kwargs,
).eval()
self.model_max_length = model_max_length
def get_text_embeddings(self, texts):
text_tokens_and_mask = self.tokenizer(
texts,
max_length=self.model_max_length,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
input_ids = text_tokens_and_mask["input_ids"].to(self.device)
attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
with torch.no_grad():
text_encoder_embs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
)["last_hidden_state"].detach()
return text_encoder_embs, attention_mask
@MODELS.register_module("t5")
class T5Encoder:
def __init__(
self,
from_pretrained=None,
model_max_length=120,
device="cuda",
dtype=torch.float,
cache_dir=None,
shardformer=False,
local_files_only=False,
):
assert from_pretrained is not None, "Please specify the path to the T5 model"
self.t5 = T5Embedder(
device=device,
torch_dtype=dtype,
from_pretrained=from_pretrained,
cache_dir=cache_dir,
model_max_length=model_max_length,
local_files_only=local_files_only,
)
self.t5.model.to(dtype=dtype)
self.y_embedder = None
self.model_max_length = model_max_length
self.output_dim = self.t5.model.config.d_model
self.dtype = dtype
if shardformer:
self.shardformer_t5()
def shardformer_t5(self):
from colossalai.shardformer import ShardConfig, ShardFormer
from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy
from opensora.utils.misc import requires_grad
shard_config = ShardConfig(
tensor_parallel_process_group=None,
pipeline_stage_manager=None,
enable_tensor_parallelism=False,
enable_fused_normalization=False,
enable_flash_attention=False,
enable_jit_fused=True,
enable_sequence_parallelism=False,
enable_sequence_overlap=False,
)
shard_former = ShardFormer(shard_config=shard_config)
optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
self.t5.model = optim_model.to(self.dtype)
# ensure the weights are frozen
requires_grad(self.t5.model, False)
def encode(self, text):
caption_embs, emb_masks = self.t5.get_text_embeddings(text)
caption_embs = caption_embs[:, None]
return dict(y=caption_embs, mask=emb_masks)
def null(self, n):
null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
return null_y
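# Hedged usage sketch (assumption): "DeepFloyd/t5-v1_1-xxl" is an illustrative T5 checkpoint; any
# path accepted by T5EncoderModel.from_pretrained works. The returned dict matches the y/mask
# inputs expected by the diffusion model's forward pass.
def _example_t5_encode():
    enc = T5Encoder(from_pretrained="DeepFloyd/t5-v1_1-xxl", model_max_length=120, device="cuda", dtype=torch.float16)
    out = enc.encode(["a timelapse of a blooming flower"])
    return out["y"].shape, out["mask"].shape  # [1, 1, 120, d_model] and [1, 120]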
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
BAD_PUNCT_REGEX = re.compile(
    r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\\/\*]{1,}"
)  # noqa
def clean_caption(caption):
import urllib.parse as ul
from bs4 import BeautifulSoup
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[‘’]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = basic_clean(caption)
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
    caption = caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
def text_preprocessing(text, use_text_preprocessing: bool = True):
if use_text_preprocessing:
        # The exact text cleaning used at training time; clean_caption is intentionally applied twice:
text = clean_caption(text)
text = clean_caption(text)
return text
else:
return text.lower().strip()
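# Hedged illustration: how the cleaning pipeline above normalizes a noisy caption. The example
# string is made up; the exact output depends on the regexes defined earlier in this file.
def _example_text_preprocessing():
    raw = "Check out my new cat photo https://example.com @cat_lover &amp; more"
    return text_preprocessing(raw)  # lowercased, with the URL, @handle and HTML entity removed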
from .vae import VideoAutoencoderKL, VideoAutoencoderKLTemporalDecoder
import torch
import torch.nn as nn
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from einops import rearrange
from opensora.registry import MODELS
@MODELS.register_module()
class VideoAutoencoderKL(nn.Module):
def __init__(self, from_pretrained=None, micro_batch_size=None):
super().__init__()
self.module = AutoencoderKL.from_pretrained(from_pretrained)
self.out_channels = self.module.config.latent_channels
self.patch_size = (1, 8, 8)
self.micro_batch_size = micro_batch_size
def encode(self, x):
# x: (B, C, T, H, W)
B = x.shape[0]
x = rearrange(x, "B C T H W -> (B T) C H W")
if self.micro_batch_size is None:
x = self.module.encode(x).latent_dist.sample().mul_(0.18215)
else:
bs = self.micro_batch_size
x_out = []
for i in range(0, x.shape[0], bs):
x_bs = x[i : i + bs]
x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215)
x_out.append(x_bs)
x = torch.cat(x_out, dim=0)
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
return x
def decode(self, x):
# x: (B, C, T, H, W)
B = x.shape[0]
x = rearrange(x, "B C T H W -> (B T) C H W")
if self.micro_batch_size is None:
x = self.module.decode(x / 0.18215).sample
else:
bs = self.micro_batch_size
x_out = []
for i in range(0, x.shape[0], bs):
x_bs = x[i : i + bs]
x_bs = self.module.decode(x_bs / 0.18215).sample
x_out.append(x_bs)
x = torch.cat(x_out, dim=0)
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
return x
def get_latent_size(self, input_size):
for i in range(3):
assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
return input_size
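# Hedged usage sketch (assumption): "stabilityai/sd-vae-ft-ema" is an illustrative AutoencoderKL
# checkpoint with 4 latent channels; the commit itself does not pin one. Frames are flattened into
# the batch dimension, so each frame is encoded/decoded independently (optionally in micro-batches).
def _example_video_vae_roundtrip():
    vae = VideoAutoencoderKL(from_pretrained="stabilityai/sd-vae-ft-ema", micro_batch_size=4)
    video = torch.randn(1, 3, 8, 256, 256)  # pixel-space video, [B, C, T, H, W]
    latent = vae.encode(video)  # [1, 4, 8, 32, 32] after the 8x spatial downsampling
    recon = vae.decode(latent)  # back to [1, 3, 8, 256, 256]
    return vae.get_latent_size([8, 256, 256]), recon.shape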
@MODELS.register_module()
class VideoAutoencoderKLTemporalDecoder(nn.Module):
def __init__(self, from_pretrained=None):
super().__init__()
self.module = AutoencoderKLTemporalDecoder.from_pretrained(from_pretrained)
self.out_channels = self.module.config.latent_channels
self.patch_size = (1, 8, 8)
def encode(self, x):
raise NotImplementedError
def decode(self, x):
B, _, T = x.shape[:3]
x = rearrange(x, "B C T H W -> (B T) C H W")
x = self.module.decode(x / 0.18215, num_frames=T).sample
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
return x
def get_latent_size(self, input_size):
for i in range(3):
assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
return input_size
import torch
import torch.nn as nn
from opensora.models.vsr.safmn_arch import SAFMN
import torch.nn.functional as F
from einops import rearrange
from timm.models.vision_transformer import Mlp
from opensora.models.layers.blocks import (
Attention,
MultiHeadCrossAttention,
PatchEmbed3D,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
)
# high pass filter
def high_pass_filter(x, kernel_size=21):
"""
对输入张量进行高通滤波,提取高频和低频部分。
参数:
x (torch.Tensor): 形状为 [B, C, T, H, W] 的输入张量,值范围在 [-1, 1]。
kernel_size (int): 高斯核的大小。
返回:
high_freq (torch.Tensor): 高频部分,形状与 x 相同。
low_freq (torch.Tensor): 低频部分,形状与 x 相同。
"""
    # derive sigma from the kernel size
    sigma = kernel_size / 6
    # device and dtype of the input tensor
    device, dtype = x.device, x.dtype
    # reshape [B, C, T, H, W] -> [B*T, C, H, W]
B, C, T, H, W = x.shape
x_reshaped = x.contiguous().view(B * T, C, H, W)
    # build a normalized 1-D Gaussian kernel
def get_gaussian_kernel(kernel_size, sigma):
axis = torch.arange(kernel_size, dtype=dtype, device=device) - kernel_size // 2
gaussian = torch.exp(-0.5 * (axis / sigma) ** 2)
gaussian /= gaussian.sum()
return gaussian
gaussian_1d = get_gaussian_kernel(kernel_size, sigma)
gaussian_2d = torch.outer(gaussian_1d, gaussian_1d)
gaussian_3d = gaussian_2d.unsqueeze(0).unsqueeze(0) # [1, 1, H, W]
    # expand the Gaussian kernel to one filter per channel (depthwise)
    gaussian_kernel = gaussian_3d.expand(C, 1, kernel_size, kernel_size)
    # depthwise convolution with F.conv2d (groups=C)
    padding = kernel_size // 2
    # low-frequency component
    low_freq_reshaped = F.conv2d(x_reshaped, gaussian_kernel, padding=padding, groups=C)
    # high-frequency component is the residual
    high_freq_reshaped = x_reshaped - low_freq_reshaped
    # reshape back [B*T, C, H, W] -> [B, C, T, H, W]
low_freq = low_freq_reshaped.view(B, C, T, H, W)
high_freq = high_freq_reshaped.view(B, C, T, H, W)
return high_freq, low_freq
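# Hedged usage sketch: the decomposition above is exact by construction (high_freq = x - low_freq),
# so summing the two outputs reconstructs the input. Shapes below are illustrative.
def _example_high_pass_filter():
    x = torch.rand(2, 3, 4, 64, 64) * 2 - 1  # [B, C, T, H, W], values in [-1, 1]
    high_freq, low_freq = high_pass_filter(x, kernel_size=21)
    assert high_freq.shape == x.shape and low_freq.shape == x.shape
    assert torch.allclose(high_freq + low_freq, x, atol=1e-6)
    return high_freq.abs().mean(), low_freq.abs().mean()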
# depth-wise separable convolution
class DepthWiseSeparableResBlock(nn.Module):
def __init__(self, in_channels, kernel_size=3, stride=1, padding=1, bias=False):
super(DepthWiseSeparableResBlock, self).__init__()
self.dwconv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias) # groups=in_channels,
# self.conv1 = nn.Conv2d(in_channels, in_channels, 1, bias=bias)
self.gelu = nn.GELU()
self.dwconv2 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias) # groups=in_channels,
# self.conv2 = nn.Conv2d(in_channels, in_channels, 1, bias=bias)
def forward(self, x):
residual = x
out = self.dwconv1(x)
# out = self.conv1(out)
out = self.gelu(out)
out = self.dwconv2(out)
# out = self.conv2(out)
out += residual
return out
# temporal transformer block
class TemporalTransformerBlock(nn.Module):
def __init__(self):
super(TemporalTransformerBlock, self).__init__()
# temporal norm
self.temporal_norm = get_layernorm(1152, eps=1e-6, affine=False, use_kernel=True)
# temporal self-attention
self.temporal_attn = Attention(
dim=1152,
num_heads=16,
qkv_bias=True,
enable_flashattn=True)
# ffn
self.temporal_ffn = Mlp(in_features=1152, hidden_features=4608, out_features=1152, act_layer=nn.GELU)
def forward(self, x):
residual = x
out = self.temporal_norm(x)
out = self.temporal_attn(out)
out = self.temporal_ffn(out)
out += residual
return out
# frequency-decoupled information extractor
class FrequencyDecoupledInfoExtractor(nn.Module):
def __init__(self, in_channels, hidden_channels, kernel_size=3, stride=1, padding=1, bias=True):
super(FrequencyDecoupledInfoExtractor, self).__init__()
### spatial branch ###
self.safmn = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=4, use_res=True)
state_dict = torch.load('/mnt/bn/videodataset/VSR/pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth')
self.safmn.load_state_dict(state_dict['params_ema'], strict=True)
# high-frequency branch
# self.hf_convin = nn.Conv2d(in_channels, hidden_channels, kernel_size, stride, padding, bias=bias)
# self.hf_convout = nn.Conv2d(hidden_channels, in_channels, kernel_size, stride, padding, bias=bias)
# hf_layer = []
# for i in range(8):
# hf_layer.append(DepthWiseSeparableResBlock(hidden_channels, kernel_size, stride=1, padding=1, bias=bias))
# self.hf_body = nn.Sequential(*hf_layer)
self.safmn1 = SAFMN(dim=72, n_blocks=8, ffn_scale=2.0, upscaling_factor=1, in_dim=6, use_res=True)
# low-frequency branch
# self.lf_convin = nn.Conv2d(in_channels, hidden_channels, kernel_size, stride, padding, bias=bias)
# self.lf_convout = nn.Conv2d(hidden_channels, in_channels, kernel_size, stride, padding, bias=bias)
# lf_layer = []
# for i in range(8):
# lf_layer.append(DepthWiseSeparableResBlock(hidden_channels, kernel_size, stride=1, padding=1, bias=bias))
# self.lf_body = nn.Sequential(*lf_layer)
self.safmn2 = SAFMN(dim=72, n_blocks=8, ffn_scale=2.0, upscaling_factor=1, in_dim=6, use_res=True)
### temporal branch ###
layer = []
for i in range(3):
layer.append(TemporalTransformerBlock())
self.temporal_body = nn.Sequential(*layer)
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
embed_dim=1152,
length=16,
scale=1.0,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def spatial_forward(self, x):
with torch.no_grad():
x = rearrange(x, 'B C T H W -> (B T) C H W')
x = F.interpolate(x, scale_factor=1/4, mode='bilinear')
clean_image = self.safmn(x)
clean_image = rearrange(clean_image, '(B T) C H W -> B C T H W', T=16)
high_freq, low_freq = high_pass_filter(clean_image)
fea_decouple = torch.cat([high_freq, low_freq], dim=1)
fea_decouple = rearrange(fea_decouple, 'B C T H W -> (B T) C H W')
# high-frequency branch
# hf_out = self.hf_convin(high_freq)
# hf_out = self.hf_body(hf_out)
# hf_out = self.hf_convout(hf_out) + high_freq
hf_out = self.safmn1(fea_decouple)
hf_out = rearrange(hf_out, '(B T) C H W -> B C T H W', T=16)
# low-frequency branch
# lf_out = self.lf_convin(low_freq)
# lf_out = self.lf_body(lf_out)
# lf_out = self.lf_convout(lf_out) + low_freq
lf_out = self.safmn2(fea_decouple)
lf_out = rearrange(lf_out, '(B T) C H W -> B C T H W', T=16)
return clean_image, hf_out, lf_out
def temporal_forward(self, x):
x = rearrange(x, "B (T S) C -> (B S) T C", T=16)
tpe = self.get_temporal_pos_embed().to(x.device, x.dtype)
x = x + tpe
x = self.temporal_body(x)
x = rearrange(x, "(B S) T C -> B (T S) C", S=256)
return x
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import ops
# from basicsr.utils.registry import ARCH_REGISTRY
# Layer Norm
class LayerNorm(nn.Module):
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape, )
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
# SE
class SqueezeExcitation(nn.Module):
def __init__(self, dim, shrinkage_rate=0.25):
super().__init__()
hidden_dim = int(dim * shrinkage_rate)
self.gate = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(dim, hidden_dim, 1, 1, 0),
nn.GELU(),
nn.Conv2d(hidden_dim, dim, 1, 1, 0),
nn.Sigmoid(),
)
def forward(self, x):
return x * self.gate(x)
# Channel MLP: Conv1*1 -> Conv1*1
class ChannelMLP(nn.Module):
def __init__(self, dim, growth_rate=2.0):
super().__init__()
hidden_dim = int(dim * growth_rate)
self.mlp = nn.Sequential(
nn.Conv2d(dim, hidden_dim, 1, 1, 0),
nn.GELU(),
nn.Conv2d(hidden_dim, dim, 1, 1, 0)
)
def forward(self, x):
return self.mlp(x)
# MBConv: Conv1*1 -> DW Conv3*3 -> [SE] -> Conv1*1
class MBConv(nn.Module):
def __init__(self, dim, growth_rate=2.0):
super().__init__()
hidden_dim = int(dim * growth_rate)
self.mbconv = nn.Sequential(
nn.Conv2d(dim, hidden_dim, 1, 1, 0),
nn.GELU(),
nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim),
nn.GELU(),
SqueezeExcitation(hidden_dim),
nn.Conv2d(hidden_dim, dim, 1, 1, 0)
)
def forward(self, x):
return self.mbconv(x)
# CCM
class CCM(nn.Module):
def __init__(self, dim, growth_rate=2.0):
super().__init__()
hidden_dim = int(dim * growth_rate)
self.ccm = nn.Sequential(
nn.Conv2d(dim, hidden_dim, 3, 1, 1),
nn.GELU(),
nn.Conv2d(hidden_dim, dim, 1, 1, 0)
)
def forward(self, x):
return self.ccm(x)
# SAFM
class SAFM(nn.Module):
def __init__(self, dim, n_levels=4):
super().__init__()
self.n_levels = n_levels
chunk_dim = dim // n_levels
# Spatial Weighting
self.mfr = nn.ModuleList([nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim) for i in range(self.n_levels)])
        # Feature Aggregation
self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)
# Activation
self.act = nn.GELU()
def forward(self, x):
h, w = x.size()[-2:]
xc = x.chunk(self.n_levels, dim=1)
out = []
for i in range(self.n_levels):
if i > 0:
p_size = (h//2**i, w//2**i)
s = F.adaptive_max_pool2d(xc[i], p_size)
s = self.mfr[i](s)
s = F.interpolate(s, size=(h, w), mode='nearest')
else:
s = self.mfr[i](xc[i])
out.append(s)
out = self.aggr(torch.cat(out, dim=1))
out = self.act(out) * x
return out
class AttBlock(nn.Module):
def __init__(self, dim, ffn_scale=2.0):
super().__init__()
self.norm1 = LayerNorm(dim)
self.norm2 = LayerNorm(dim)
# Multiscale Block
self.safm = SAFM(dim)
# Feedforward layer
self.ccm = CCM(dim, ffn_scale)
def forward(self, x):
x = self.safm(self.norm1(x)) + x
x = self.ccm(self.norm2(x)) + x
return x
# @ARCH_REGISTRY.register()
class SAFMN(nn.Module):
def __init__(self, dim, n_blocks=8, ffn_scale=2.0, upscaling_factor=4, in_dim=3, use_res=True):
super().__init__()
self.use_res = use_res
self.to_feat = nn.Conv2d(in_dim, dim, 3, 1, 1)
self.feats = nn.Sequential(*[AttBlock(dim, ffn_scale) for _ in range(n_blocks)])
self.to_img = nn.Sequential(
nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1),
nn.PixelShuffle(upscaling_factor)
)
def forward(self, x):
x = self.to_feat(x)
if self.use_res:
x = self.feats(x) + x
else:
x = self.feats(x)
x = self.to_img(x)
return x
# if __name__== '__main__':
# #############Test Model Complexity #############
# from fvcore.nn import flop_count_table, FlopCountAnalysis, ActivationCountAnalysis
# # x = torch.randn(1, 3, 640, 360)
# # x = torch.randn(1, 3, 427, 240)
# x = torch.randn(1, 3, 320, 180)
# # x = torch.randn(1, 3, 256, 256)
# model = SAFMN(dim=36, n_blocks=8, ffn_scale=2.0, upscaling_factor=4)
# # model = SAFMN(dim=36, n_blocks=12, ffn_scale=2.0, upscaling_factor=2)
# print(model)
# print(f'params: {sum(map(lambda x: x.numel(), model.parameters()))}')
# print(flop_count_table(FlopCountAnalysis(model, x), activations=ActivationCountAnalysis(model, x)))
# output = model(x)
# print(output.shape)
import torch
import torch.nn as nn
import xformers.ops
# spatial feature refiner
class SpatialFeatureRefiner(nn.Module):
def __init__(self, hidden_channels):
super(SpatialFeatureRefiner, self).__init__()
# high-frequency branch
self.hf_linear = nn.Linear(hidden_channels, hidden_channels * 2)
# low-frequency branch
self.lf_linear = nn.Linear(hidden_channels, hidden_channels * 2)
# fusion
self.gelu = nn.GELU()
self.fusion_linear = nn.Linear(hidden_channels * 2, hidden_channels)
def forward(self, hf_feature, lf_feature, x):
# high-frequency branch
hf_feature = self.hf_linear(hf_feature)
scale_hf, shift_hf = hf_feature.chunk(2, dim=-1)
x_hf = x * scale_hf + shift_hf
# low-frequency branch
lf_feature = self.lf_linear(lf_feature)
scale_lf, shift_lf = lf_feature.chunk(2, dim=-1)
x_lf = x * scale_lf + shift_lf
# fusion
x_fusion = torch.cat([x_hf, x_lf], dim=-1)
x_fusion = self.gelu(x_fusion)
x_fusion = self.fusion_linear(x_fusion)
return x_fusion
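# Hedged usage sketch: the high/low-frequency features produce per-token scale/shift pairs that
# modulate x, and the two modulated branches are concatenated and fused back to hidden_channels.
# All shapes here are illustrative.
def _example_spatial_feature_refiner():
    refiner = SpatialFeatureRefiner(hidden_channels=64)
    x = torch.randn(2, 16, 64)  # [B, N_tokens, C]
    hf_feature = torch.randn(2, 16, 64)
    lf_feature = torch.randn(2, 16, 64)
    return refiner(hf_feature, lf_feature, x).shape  # torch.Size([2, 16, 64])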
# low-frequency temporal guider
class LFTemporalGuider(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
super(LFTemporalGuider, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
        # query: image tokens (x); key/value: condition tokens (cond); mask: per-sample lengths of valid condition tokens
B, N, C = x.shape
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
x = x.view(B, -1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
from copy import deepcopy
import torch.nn as nn
from mmengine.registry import Registry
def build_module(module, builder, **kwargs):
"""Build module from config or return the module itself.
Args:
module (Union[dict, nn.Module]): The module to build.
builder (Registry): The registry to build module.
*args, **kwargs: Arguments passed to build function.
Returns:
Any: The built module.
"""
if isinstance(module, dict):
cfg = deepcopy(module)
for k, v in kwargs.items():
cfg[k] = v
return builder.build(cfg)
elif isinstance(module, nn.Module):
return module
elif module is None:
return None
else:
        raise TypeError(f"Only dict and nn.Module are supported, but got {type(module)}.")
MODELS = Registry(
"model",
locations=["opensora.models"],
)
SCHEDULERS = Registry(
"scheduler",
locations=["opensora.schedulers"],
)
DATASETS = Registry(
"dataset",
locations=["opensora.datasets"],
)
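# Hedged usage sketch (assumption): the config dict below is illustrative; real configs live in the
# project's config files, and "STDiT_QKNorm_RoPE_XL/2" must have been imported so that its
# @MODELS.register_module decorator has run before build_module is called.
def _example_build_module():
    model_cfg = dict(type="STDiT_QKNorm_RoPE_XL/2", input_size=(16, 32, 32), qk_norm=True, rope=True)
    model = build_module(model_cfg, MODELS)  # dict config -> looked up and built via the registry
    return type(model).__name__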