Commit 1f5da520 authored by yangzhong

git init

import matplotlib.pyplot as plt
import torch
import numpy as np
from PIL import Image
import cv2
import torchvision.transforms
"""
1. Gaussian noise
generate_gaussian_noise
random_generate_gaussian_noise
random_add_gaussian_noise
add_gaussian_noise
2. Poisson noise
random_add_poisson_noise
random_generate_poisson_noise
generate_poisson_noise
add_poisson_noise
"""
'''
generate_gaussian_noise
add_gaussian_noise
generate_gaussian_noise_pt
add_gaussian_noise_pt
random_generate_gaussian_noise
random_add_gaussian_noise
random_generate_gaussian_noise_pt
random_add_gaussian_noise_pt
'''
# -------------------------------------------------------------------- #
# --------------------random_add_gaussian_noise----------------------- #
# -------------------------------------------------------------------- #
def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
sigma = torch.rand(
img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_gaussian_noise_pt(img, sigma, gray_noise)
def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
"""Add Gaussian noise (PyTorch version).
Args:
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
scale (float | Tensor): Noise scale. Default: 1.0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
b, _, h, w = img.size()
if not isinstance(sigma, (float, int)):
sigma = sigma.view(img.size(0), 1, 1, 1)
if isinstance(gray_noise, (float, int)):
cal_gray_noise = gray_noise > 0
else:
gray_noise = gray_noise.view(b, 1, 1, 1)
cal_gray_noise = torch.sum(gray_noise) > 0
if cal_gray_noise:
noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
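        # when sigma comes from random_generate_gaussian_noise_pt it has shape (b, 1, 1, 1),
        # so broadcasting with the (h, w) sample gives (b, 1, h, w); a scalar sigma only
        # reshapes cleanly here when b == 1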
noise_gray = noise_gray.view(b, 1, h, w)
# always calculate color noise
noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
return noise
def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
"""Add Gaussian noise (PyTorch version).
Args:
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
scale (float | Tensor): Noise scale. Default: 1.0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
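# Usage sketch (illustrative only; shapes and ranges below are made-up examples):
#   imgs = torch.rand(4, 3, 64, 64)  # float32 batch in [0, 1]
#   noisy = random_add_gaussian_noise_pt(imgs, sigma_range=(1, 30), gray_prob=0.4)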
# -------------------------------------------------------------------- #
# --------------------random_add_poisson_noise------------------------ #
# -------------------------------------------------------------------- #
def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
scale = torch.rand(
img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_poisson_noise_pt(img, scale, gray_noise)
def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
"""Generate a batch of poisson noise (PyTorch version)
Args:
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
Default: 1.0.
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
Returns:
        (Tensor): Generated noise, shape (b, c, h, w), float32.
"""
b, _, h, w = img.size()
if isinstance(gray_noise, (float, int)):
cal_gray_noise = gray_noise > 0
else:
gray_noise = gray_noise.view(b, 1, 1, 1)
cal_gray_noise = torch.sum(gray_noise) > 0
if cal_gray_noise:
        # img_gray = rgb_to_grayscale(img, num_output_channels=1)
        img_gray = img[:, 0, :, :]  # use the first channel as a gray proxy, shape (b, h, w)
        img_gray = torch.unsqueeze(img_gray, 1)  # shape (b, 1, h, w)
# round and clip image for counting vals correctly
img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
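        # rounding the unique-value count up to a power of two mirrors the Poisson mode of
        # skimage.util.random_noise: img * vals is treated as photon counts, so the noise
        # variance scales with image intensity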
vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img_gray * vals) / vals
noise_gray = out - img_gray
noise_gray = noise_gray.expand(b, 3, h, w)
# always calculate color noise
# round and clip image for counting vals correctly
img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img * vals) / vals
noise = out - img
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
if not isinstance(scale, (float, int)):
scale = scale.view(b, 1, 1, 1)
return noise * scale
def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
"""Add poisson noise to a batch of images (PyTorch version).
Args:
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
Default: 1.0.
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
noise = generate_poisson_noise_pt(img, scale, gray_noise)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
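# Usage sketch (illustrative only; shapes and ranges below are made-up examples):
#   imgs = torch.rand(2, 3, 128, 128)  # float32 batch in [0, 1]
#   noisy = random_add_poisson_noise_pt(imgs, scale_range=(0.05, 2.5), gray_prob=0.4)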
if __name__ == '__main__':
    gaussian_noise_prob2 = 0.5
    noise_range2 = [1, 25]
    poisson_scale_range2 = [0.05, 2.5]
    gray_noise_prob2 = 0.4
    # img = cv2.imread('../qj.png')
    # img = np.float32(img / 255.)
    # noise = random_generate_poisson_noise_pt(img, noise_range2, gray_noise_prob2)
    # img_noise = random_add_gaussian_noise_pt(img, noise_range2, gray_noise_prob2)
    # print(noise.shape)
    #
    # img_noise = np.uint8((img_noise.clip(0, 1) * 255.).round())
    img = Image.open('../dog.jpg')
    img = torchvision.transforms.ToTensor()(img)
    img = img.unsqueeze(0)
    img_noise = random_add_poisson_noise_pt(img, poisson_scale_range2, gray_noise_prob2)
    img_noise = img_noise.squeeze(0).permute(1, 2, 0).detach().numpy()
    plt.imshow(img_noise)
    plt.show()
import torch
from torch.nn import functional as F
import cv2
import random
import numpy as np
'''
Supported interpolation modes for resizing:
    'bilinear' | 'bicubic' | 'area'
'''
def random_resizing(image, updown_type, resize_prob, mode_list, resize_range):
    b, c, h, w = image.shape
    # random.choices returns a list (e.g. ["up"]), so take the first element with [0]
    updown_type = random.choices(updown_type, resize_prob)[0]
    mode = random.choice(mode_list)
    if updown_type == "up":
        scale = np.random.uniform(1, resize_range[1])
    elif updown_type == "down":
        scale = np.random.uniform(resize_range[0], 1)
    else:
        scale = 1
    image = F.interpolate(image, scale_factor=scale, mode=mode)
    # image = cv2.resize(image, (w, h), interpolation=flags)
    image = torch.clamp(image, 0.0, 1.0)
    return image
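# Usage sketch (illustrative only; the probabilities and ranges are hypothetical,
# Real-ESRGAN-style values, not taken from this file):
#   x = torch.rand(1, 3, 256, 256)
#   x = random_resizing(x, ["up", "down", "keep"], [0.2, 0.7, 0.1],
#                       ["area", "bilinear", "bicubic"], [0.15, 1.5])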
import os
import re
from typing import Iterator, Optional
from torch.distributed import ProcessGroup
import numpy as np
import pandas as pd
import requests
import torch
import cv2
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data.distributed import DistributedSampler
from torchvision.io import write_video
from torchvision.utils import save_image
import random
from . import video_transforms
from .wavelet_color_fix import adain_color_fix
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
regex = re.compile(
r"^(?:http|ftp)s?://" # http:// or https://
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain...
r"localhost|" # localhost...
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip
r"(?::\d+)?" # optional port
r"(?:/?|[/?]\S+)$",
re.IGNORECASE,
)
def is_url(url):
return re.match(regex, url) is not None
def read_file(input_path):
if input_path.endswith(".csv"):
return pd.read_csv(input_path)
elif input_path.endswith(".parquet"):
return pd.read_parquet(input_path)
else:
raise NotImplementedError(f"Unsupported file format: {input_path}")
def download_url(input_path):
output_dir = "cache"
if not os.path.exists(output_dir):
os.makedirs(output_dir)
base_name = os.path.basename(input_path)
output_path = os.path.join(output_dir, base_name)
img_data = requests.get(input_path).content
with open(output_path, "wb") as handler:
handler.write(img_data)
print(f"URL {input_path} downloaded to {output_path}")
return output_path
def temporal_random_crop(vframes, num_frames, frame_interval):
temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
total_frames = len(vframes)
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
assert end_frame_ind - start_frame_ind >= num_frames
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
video = vframes[frame_indice]
return video
def compute_bidirectional_optical_flow(video_frames):
video_frames = video_frames.permute(0, 2, 3, 1).numpy() # T C H W -> T H W C
T, H, W, _ = video_frames.shape
bidirectional_flow = torch.zeros((2, T - 1, H, W))
for t in range(T - 1):
prev_frame = cv2.cvtColor(video_frames[t], cv2.COLOR_RGB2GRAY)
next_frame = cv2.cvtColor(video_frames[t + 1], cv2.COLOR_RGB2GRAY)
        # compute forward optical flow
        flow_forward = cv2.calcOpticalFlowFarneback(prev_frame, next_frame, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        # compute backward optical flow
        flow_backward = cv2.calcOpticalFlowFarneback(next_frame, prev_frame, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        # merge the forward and backward flow maps, (H, W, 2) -> (2, H, W)
        bidirectional_flow[:, t] = torch.from_numpy(flow_forward + flow_backward).permute(2, 0, 1)
return bidirectional_flow
# Gaussian blur helper
def blur_video(video, kernel_size=(21, 21), sigma=21):
    """
    Apply Gaussian blur to every frame of a video.
    Args:
        video (torch.Tensor): Input video of shape [T, C, H, W]
        kernel_size (tuple): Size of the blur kernel. Default: (21, 21)
        sigma (float): Standard deviation of the Gaussian kernel. Default: 21
    Returns:
        torch.Tensor: Blurred video of shape [T, C, H, W]
    """
    blurred_frames = []
    for frame in video:
        # convert to numpy, shape (H, W, C)
        frame_np = frame.permute(1, 2, 0).numpy()
        # apply Gaussian blur with OpenCV
        blurred_frame = cv2.GaussianBlur(frame_np, kernel_size, sigma)
        # convert back to a PyTorch tensor, shape (C, H, W)
        blurred_frame = torch.from_numpy(blurred_frame).permute(2, 0, 1)
        blurred_frames.append(blurred_frame)
    # stack the blurred frames back into a video, shape [T, C, H, W]
    return torch.stack(blurred_frames)
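# Usage sketch (illustrative only):
#   video = torch.rand(16, 3, 256, 256)  # [T, C, H, W], float32 in [0, 1]
#   blurred = blur_video(video, kernel_size=(21, 21), sigma=21)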
def get_transforms_video(name="center", image_size=(256, 256)):
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[1], "image_size must be square for center crop"
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
# video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(image_size[0]),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "resize_crop":
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.ResizeCrop(image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "direct_crop":
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomCrop(image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
else:
raise NotImplementedError(f"Transform {name} not implemented")
return transform_video
def get_transforms_image(name="center", image_size=(256, 256)):
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[1], "Image size must be square for center crop"
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "resize_crop":
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
else:
raise NotImplementedError(f"Transform {name} not implemented")
return transform
def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
image = pil_loader(path)
if transform is None:
transform = get_transforms_image(image_size=image_size, name=transform_name)
image = transform(image)
video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
video = video.permute(1, 0, 2, 3)
return video
def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
if transform is None:
transform = get_transforms_video(image_size=image_size, name=transform_name)
video = transform(vframes) # T C H W
video = video.permute(1, 0, 2, 3)
return video
def read_from_path(path, image_size, transform_name="center"):
if is_url(path):
path = download_url(path)
ext = os.path.splitext(path)[-1].lower()
if ext.lower() in VID_EXTENSIONS:
return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
else:
assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1), force_video=False, align_method=None, validation_video=None):
    """
    Args:
        x (Tensor): shape [C, T, H, W]
    """
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
assert x.ndim == 4
if not force_video and x.shape[1] == 1: # T = 1: save as image
save_path += ".png"
x = x.squeeze(1)
save_image([x], save_path, normalize=normalize, value_range=value_range)
else:
save_path += ".mp4"
if normalize:
low, high = value_range
x.clamp_(min=low, max=high)
x.sub_(low).div_(max(high - low, 1e-5))
if align_method:
x = adain_color_fix(x, validation_video)
x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
write_video(save_path, x, fps=int(fps), video_codec="h264")
# print(f"Saved to {save_path}")
return save_path
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
class StatefulDistributedSampler(DistributedSampler):
def __init__(
self,
dataset: Dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
self.start_index: int = 0
def __iter__(self) -> Iterator:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
def prepare_dataloader(
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
process_group: Optional[ProcessGroup] = None,
**kwargs,
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(
dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
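# Usage sketch (illustrative only): assumes torch.distributed is already initialized
# (e.g. launched with torchrun) and `my_dataset` is any map-style Dataset.
#   loader = prepare_dataloader(my_dataset, batch_size=8, shuffle=True, num_workers=4)
#   loader.sampler.set_start_index(1000)  # resume mid-epoch from the 1000th sample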
# Copyright 2024 Vchitect/Latte
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from Latte
# - This file is adapted from https://github.com/Vchitect/Latte/blob/main/datasets/video_transforms.py
import numbers
import random
import torch
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i : i + h, j : j + w]
def random_crop(clip, crop_size):
"""
Args:
clip (torch.Tensor): Video clip to be cropped. Size is (T, C, H, W)
crop_size (tuple): Desired output size (h, w)
Returns:
torch.Tensor: Cropped video of size (T, C, h, w)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
_, _, H, W = clip.shape
th, tw = crop_size
if th > H or tw > W:
raise ValueError("Crop size should be smaller than video dimensions")
i = torch.randint(0, H - th + 1, size=(1,)).item()
j = torch.randint(0, W - tw + 1, size=(1,)).item()
return crop(clip, i, j, th, tw)
def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
def resize_scale(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size[0] / min(H, W)
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
if h < th or w < tw:
clip = torch.nn.functional.interpolate(clip, size=(th, tw), mode="bilinear", align_corners=False)
h, w = clip.size(-2), clip.size(-1)
#raise ValueError("height and width must be no smaller than crop_size")
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
def center_crop_using_short_edge(clip):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h < w:
th, tw = h, h
i = 0
j = int(round((w - tw) / 2.0))
else:
th, tw = w, w
i = int(round((h - th) / 2.0))
j = 0
return crop(clip, i, j, th, tw)
def random_shift_crop(clip):
"""
Slide along the long edge, with the short edge as crop size
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h <= w:
short_edge = h
else:
short_edge = w
th, tw = short_edge, short_edge
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return crop(clip, i, j, th, tw)
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
# print(mean)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
Returns:
flipped clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
return clip.flip(-1)
class RandomCropVideo:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: randomly cropped video clip.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = self.size
if h < th or w < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class CenterCropResizeVideo:
"""
First use the short side for cropping length,
center crop video, then resize to the specified size
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop_using_short_edge(clip)
clip_center_crop_resize = resize(
clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode
)
return clip_center_crop_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class UCFCenterCropVideo:
"""
First scale to the specified size in equal proportion to the short edge,
then center cropping
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
clip_center_crop = center_crop(clip_resize, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class KineticsRandomCropResizeVideo:
"""
    Slide along the long edge, with the short edge as crop size, and resize to the desired size.
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
clip_random_crop = random_shift_crop(clip)
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
return clip_resize
class CenterCropVideo:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop(clip, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class NormalizeVideo:
"""
Normalize the video clip by mean subtraction and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
return to_tensor(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class RandomCrop:
"""
Perform random cropping on a video tensor of shape (T, C, H, W).
"""
def __init__(self, crop_size):
"""
Args:
crop_size (tuple): Desired output size (h, w)
"""
self.crop_size = crop_size
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Video tensor of size (T, C, H, W)
Returns:
torch.tensor: Cropped video tensor of size (T, C, h, w), dtype=torch.float
"""
return random_crop(clip, self.crop_size)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(crop_size={self.crop_size})"
class RandomHorizontalFlipVideo:
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (T, C, H, W)
Return:
clip (torch.tensor): Size is (T, C, H, W)
"""
if random.random() < self.p:
clip = hflip(clip)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})"
# ------------------------------------------------------------
# --------------------- Sampling ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
"""Temporally crop the given frame indices at a random location.
Args:
        size (int): Desired number of frames to be sampled for the model.
"""
def __init__(self, size):
self.size = size
def __call__(self, total_frames):
rand_end = max(0, total_frames - self.size - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + self.size, total_frames)
return begin_index, end_index
if __name__ == "__main__":
import os
import numpy as np
import torchvision.io as io
from torchvision import transforms
from torchvision.utils import save_image
vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW")
trans = transforms.Compose(
[
ToTensorVideo(),
RandomHorizontalFlipVideo(),
UCFCenterCropVideo(512),
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
target_video_len = 32
frame_interval = 1
total_frames = len(vframes)
print(total_frames)
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
# Sampling video frames
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
# print(start_frame_ind)
# print(end_frame_ind)
assert end_frame_ind - start_frame_ind >= target_video_len
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
print(frame_indice)
select_vframes = vframes[frame_indice]
print(select_vframes.shape)
print(select_vframes.dtype)
select_vframes_trans = trans(select_vframes)
print(select_vframes_trans.shape)
print(select_vframes_trans.dtype)
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
print(select_vframes_trans_int.dtype)
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
for i in range(target_video_len):
save_image(
select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1)
)
class ResizeCrop:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
clip = resize_crop_to_fill(clip, self.size)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i : i + h, j : j + w]
def resize_crop_to_fill(clip, target_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = target_size[0], target_size[1]
rh, rw = th / h, tw / w
if rh > rw:
sh, sw = th, round(w * rh)
clip = resize(clip, (sh, sw), "bilinear")
i = 0
        j = int(round((sw - tw) / 2.0))
else:
sh, sw = round(h * rw), tw
clip = resize(clip, (sh, sw), "bilinear")
        i = int(round((sh - th) / 2.0))
j = 0
assert i + th <= clip.size(-2) and j + tw <= clip.size(-1)
return crop(clip, i, j, th, tw)
'''
# --------------------------------------------------------------------------------
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
# --------------------------------------------------------------------------------
'''
import torch
from PIL import Image
from torch import Tensor
from torch.nn import functional as F
from einops import rearrange
from torchvision.transforms import ToTensor, ToPILImage
def adain_color_fix(target: Tensor, source: Tensor):
    # target/source: video tensors of shape (C, T, H, W), e.g. torch.Size([3, 16, 256, 256])
# Apply adaptive instance normalization
target = rearrange(target, "C T H W -> T C H W")
source = rearrange(source, "C T H W -> T C H W")
result_tensor = adaptive_instance_normalization(target, source)
result_tensor = rearrange(result_tensor, "T C H W -> C T H W").clamp_(0.0, 1.0)
# Convert tensor back to image
# to_image = ToPILImage()
# result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
return result_tensor
def wavelet_color_fix(target: Image, source: Image):
# Convert images to tensors
to_tensor = ToTensor()
target_tensor = to_tensor(target).unsqueeze(0)
source_tensor = to_tensor(source).unsqueeze(0)
# Apply wavelet reconstruction
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
# Convert tensor back to image
to_image = ToPILImage()
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
return result_image
def calc_mean_std(feat: Tensor, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
    """Adaptive instance normalization.
    Adjust the reference features so that their color and illumination match
    those of the degraded features.
    Args:
        content_feat (Tensor): The reference features.
        style_feat (Tensor): The degraded features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
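    # normalize the reference (content) feature to zero mean / unit std per channel,
    # then re-color it with the statistics of the degraded (style) feature:
    # out = (content - mu_c) / sigma_c * sigma_s + mu_s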
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
def wavelet_blur(image: Tensor, radius: int):
"""
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
# add channel dimensions to the kernel to make it a 4D tensor
kernel = kernel[None, None]
# repeat the kernel across all input channels
kernel = kernel.repeat(3, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
# apply convolution
output = F.conv2d(image, kernel, groups=3, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels=5):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq += (image - low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
"""
Apply wavelet decomposition, so that the content will have the same color as the style.
"""
# calculate the wavelet decomposition of the content feature
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
# reconstruct the content feature with the style's high frequency
return content_high_freq + style_low_freq
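# Usage sketch (illustrative only): both fixes expect values in [0, 1].
#   fixed_video = adain_color_fix(pred_video, source_video)  # (C, T, H, W) tensors
#   fixed_image = wavelet_color_fix(pred_image, source_image)  # PIL images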
from .dit import *
from .latte import *
from .pixart import *
from .stdit import *
from .text_encoder import *
from .vae import *
from .dit import DiT, DiT_XL_2, DiT_XL_2x2
# Modified from Meta DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange
from timm.models.vision_transformer import Mlp
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
FinalLayer,
LabelEmbedder,
PatchEmbed3D,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
class DiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(
self,
hidden_size,
num_heads,
mlp_ratio=4.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.enable_flashattn = enable_flashattn
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.attn = Attention(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=enable_flashattn,
)
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
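        # adaLN-Zero: the conditioning vector c yields a per-sample shift/scale for each
        # LayerNorm and a gate for each residual branch; adaLN_modulation is zero-initialized
        # in initialize_weights, so every block starts out as an identity mapping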
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1, x, shift_msa, scale_msa))
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2, x, shift_mlp, scale_mlp))
return x
@MODELS.register_module()
class DiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=(16, 32, 32),
in_channels=4,
patch_size=(1, 2, 2),
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
condition="text",
no_temporal_pos_emb=False,
caption_channels=512,
model_max_length=77,
dtype=torch.float32,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.input_size = input_size
num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
self.num_patches = num_patches
self.num_temporal = input_size[0] // patch_size[0]
self.num_spatial = num_patches // self.num_temporal
self.num_heads = num_heads
self.dtype = dtype
self.use_text_encoder = not condition.startswith("label")
if enable_flashattn:
assert dtype in [
torch.float16,
torch.bfloat16,
], f"Flash attention only supports float16 and bfloat16, but got {self.dtype}"
self.no_temporal_pos_emb = no_temporal_pos_emb
self.mlp_ratio = mlp_ratio
self.depth = depth
assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT"
self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed())
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
self.x_embedder = PatchEmbed3D(patch_size, in_channels, embed_dim=hidden_size)
if not self.use_text_encoder:
num_classes = int(condition.split("_")[-1])
self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
else:
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
uncond_prob=class_dropout_prob,
act_layer=approx_gelu,
token_num=1, # pooled token
)
self.t_embedder = TimestepEmbedder(hidden_size)
self.blocks = nn.ModuleList(
[
DiTBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
enable_flashattn=enable_flashattn,
enable_layernorm_kernel=enable_layernorm_kernel,
)
for _ in range(depth)
]
)
self.final_layer = FinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
self.initialize_weights()
self.enable_flashattn = enable_flashattn
self.enable_layernorm_kernel = enable_layernorm_kernel
def get_spatial_pos_embed(self):
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
self.input_size[1] // self.patch_size[1],
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def unpatchify(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def forward(self, x, t, y):
"""
Forward pass of DiT.
x: (B, C, T, H, W) tensor of inputs
t: (B,) tensor of diffusion timesteps
y: list of text
"""
# origin inputs should be float32, cast to specified dtype
x = x.to(self.dtype)
if self.use_text_encoder:
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # (B, N, D)
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
x = x + self.pos_embed_spatial
if not self.no_temporal_pos_emb:
x = rearrange(x, "b t s d -> b s t d")
x = x + self.pos_embed_temporal
x = rearrange(x, "b s t d -> b (t s) d")
else:
x = rearrange(x, "b t s d -> b (t s) d")
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
if self.use_text_encoder:
y = y.squeeze(1).squeeze(1)
condition = t + y
# blocks
for _, block in enumerate(self.blocks):
c = condition
x = auto_grad_checkpoint(block, x, c) # (B, N, D)
# final process
x = self.final_layer(x, condition) # (B, N, num_patches * out_channels)
x = self.unpatchify(x) # (B, out_channels, T, H, W)
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
                if module.weight.requires_grad:
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
# Zero-out text embedding layers:
if self.use_text_encoder:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
@MODELS.register_module("DiT-XL/2")
def DiT_XL_2(from_pretrained=None, **kwargs):
model = DiT(
depth=28,
hidden_size=1152,
patch_size=(1, 2, 2),
num_heads=16,
**kwargs,
)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
@MODELS.register_module("DiT-XL/2x2")
def DiT_XL_2x2(from_pretrained=None, **kwargs):
model = DiT(
depth=28,
hidden_size=1152,
patch_size=(2, 2, 2),
num_heads=16,
**kwargs,
)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
from .latte import Latte, Latte_XL_2, Latte_XL_2x2
# Copyright 2024 Vchitect/Latte
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from Latte
#
# This file is modified from https://github.com/Vchitect/Latte/blob/main/models/latte.py
#
# With references to:
# Latte: https://github.com/Vchitect/Latte
# DiT: https://github.com/facebookresearch/DiT/tree/main
import torch
from einops import rearrange, repeat
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.dit import DiT
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
@MODELS.register_module()
class Latte(DiT):
def forward(self, x, t, y):
"""
Forward pass of DiT.
x: (B, C, T, H, W) tensor of inputs
t: (B,) tensor of diffusion timesteps
y: list of text
"""
# origin inputs should be float32, cast to specified dtype
x = x.to(self.dtype)
# embedding
x = self.x_embedder(x) # (B, N, D)
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
x = x + self.pos_embed_spatial
x = rearrange(x, "b t s d -> b (t s) d")
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
if self.use_text_encoder:
y = y.squeeze(1).squeeze(1)
condition = t + y
condition_spatial = repeat(condition, "b d -> (b t) d", t=self.num_temporal)
condition_temporal = repeat(condition, "b d -> (b s) d", s=self.num_spatial)
# blocks
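        # Latte-style interleaving: even-indexed blocks attend over spatial tokens within
        # each frame ((b t) s d), odd-indexed blocks attend over temporal tokens at each
        # spatial location ((b s) t d); the temporal positional embedding is added once,
        # just before the first temporal block (i == 1)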
for i, block in enumerate(self.blocks):
if i % 2 == 0:
# spatial
x = rearrange(x, "b (t s) d -> (b t) s d", t=self.num_temporal, s=self.num_spatial)
c = condition_spatial
else:
# temporal
x = rearrange(x, "b (t s) d -> (b s) t d", t=self.num_temporal, s=self.num_spatial)
c = condition_temporal
if i == 1:
x = x + self.pos_embed_temporal
x = auto_grad_checkpoint(block, x, c) # (B, N, D)
if i % 2 == 0:
x = rearrange(x, "(b t) s d -> b (t s) d", t=self.num_temporal, s=self.num_spatial)
else:
x = rearrange(x, "(b s) t d -> b (t s) d", t=self.num_temporal, s=self.num_spatial)
# final process
x = self.final_layer(x, condition) # (B, N, num_patches * out_channels)
x = self.unpatchify(x) # (B, out_channels, T, H, W)
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
@MODELS.register_module("Latte-XL/2")
def Latte_XL_2(from_pretrained=None, **kwargs):
model = Latte(
depth=28,
hidden_size=1152,
patch_size=(1, 2, 2),
num_heads=16,
**kwargs,
)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
@MODELS.register_module("Latte-XL/2x2")
def Latte_XL_2x2(from_pretrained=None, **kwargs):
model = Latte(
depth=28,
hidden_size=1152,
patch_size=(2, 2, 2),
num_heads=16,
**kwargs,
)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# Latte: https://github.com/Vchitect/Latte
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import math
from typing import Any, Dict, List, Optional, Tuple, Union, KeysView
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import xformers.ops
from einops import rearrange
from timm.models.vision_transformer import Mlp
from opensora.acceleration.communications import all_to_all, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
# import ipdb
approx_gelu = lambda: nn.GELU(approximate="tanh")
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
#ipdb.set_trace()
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
#ipdb.set_trace()
return self.weight * hidden_states.to(input_dtype)
def get_layernorm(hidden_size: int, eps: float, affine: bool, use_kernel: bool):
if use_kernel:
try:
from apex.normalization import FusedLayerNorm
return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
except ImportError:
raise RuntimeError("FusedLayerNorm not available. Please install apex.")
else:
return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)
def modulate(norm_func, x, shift, scale):
# Suppose x is (B, N, D), shift is (B, D), scale is (B, D)
dtype = x.dtype
x = norm_func(x.to(torch.float32)).to(dtype)
x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
x = x.to(dtype)
return x
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
# ===============================================
# General-purpose Layers
# ===============================================
class PatchEmbed3D(nn.Module):
"""Video to Patch Embedding.
Args:
patch_size (int): Patch token size. Default: (2,4,4).
in_chans (int): Number of input video channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self,
patch_size=(2, 4, 4),
in_chans=3,
embed_dim=96,
padding=None,
norm_layer=None,
flatten=True,
):
super().__init__()
self.patch_size = patch_size
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
self.padding = padding
if padding is not None:
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=padding)
else:
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
if self.padding is None:
# padding
_, _, D, H, W = x.size()
if W % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
if H % self.patch_size[1] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
if D % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
x = self.proj(x) # (B C T H W)
if self.norm is not None:
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
return x
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
enable_flashattn: bool = False,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flashattn = enable_flashattn
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
if self.enable_flashattn: # here
qkv_permute_shape = (2, 0, 1, 3, 4)
else:
qkv_permute_shape = (2, 0, 3, 1, 4)
qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.enable_flashattn:
from flash_attn import flash_attn_func
x = flash_attn_func(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
)
else:
dtype = q.dtype
q = q * self.scale
            attn = q @ k.transpose(-2, -1)  # compute the softmax in float32 for stability, then cast back
attn = attn.to(torch.float32)
attn = attn.softmax(dim=-1)
attn = attn.to(dtype) # cast back attn to original dtype
attn = self.attn_drop(attn)
x = attn @ v
x_output_shape = (B, N, C)
if not self.enable_flashattn:
x = x.transpose(1, 2)
x = x.reshape(x_output_shape)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Attention_QKNorm_RoPE(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = LlamaRMSNorm,
enable_flashattn: bool = False,
rope=None,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flashattn = enable_flashattn
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.rotary_emb = rope
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
if self.enable_flashattn:
qkv_permute_shape = (2, 0, 1, 3, 4)
else:
qkv_permute_shape = (2, 0, 3, 1, 4)
qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
q, k, v = qkv.unbind(0)
if self.rotary_emb is not None:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
q, k = self.q_norm(q), self.k_norm(k)
if self.enable_flashattn:
from flash_attn import flash_attn_func
x = flash_attn_func(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
)
else:
dtype = q.dtype
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.to(torch.float32)  # compute the softmax in float32 for numerical stability
attn = attn.softmax(dim=-1)
attn = attn.to(dtype) # cast back attn to original dtype
attn = self.attn_drop(attn)
x = attn @ v
x_output_shape = (B, N, C)
if not self.enable_flashattn:
x = x.transpose(1, 2)
x = x.reshape(x_output_shape)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MaskedSelfAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = LlamaRMSNorm,
enable_flashattn: bool = False,
rope=None,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flashattn = enable_flashattn
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.rotary_emb = rope
def forward(self, x, mask):
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv_permute_shape = (2, 0, 3, 1, 4)
qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
q, k, v = qkv.unbind(0) # B H N C
if self.rotary_emb is not None:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
q, k = self.q_norm(q), self.k_norm(k)
mask = mask.unsqueeze(1).unsqueeze(1).repeat(1, self.num_heads, 1, 1).to(torch.float32) # B H 1 N
dtype = q.dtype
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.to(torch.float32)  # compute the softmax in float32 for numerical stability
attn = attn.masked_fill(mask == 0, -1e9)
attn = attn.softmax(dim=-1)
attn = attn.to(dtype) # cast back attn to original dtype
attn = self.attn_drop(attn)
x = attn @ v
x_output_shape = (B, N, C)
x = x.transpose(1, 2)
x = x.reshape(x_output_shape)
x = self.proj(x)
x = self.proj_drop(x)
return x
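# Minimal usage sketch (illustrative only; sizes are assumptions): `mask` is a
# (B, N) tensor with 1 for valid tokens and 0 for padding; padded positions are
# filled with -1e9 before the softmax so they receive ~zero attention weight.
def _masked_self_attention_example():
    attn = MaskedSelfAttention(dim=96, num_heads=8)
    tokens = torch.randn(2, 16, 96)  # (B, N, C)
    mask = torch.ones(2, 16)
    mask[:, 12:] = 0  # the last four tokens of each sequence are padding
    out = attn(tokens, mask)
    assert out.shape == (2, 16, 96)
    return out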
class SeqParallelAttention(Attention):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
enable_flashattn: bool = False,
) -> None:
super().__init__(
dim=dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=proj_drop,
norm_layer=norm_layer,
enable_flashattn=enable_flashattn,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape  # under sequence parallelism, N is the local (per-device) sequence length
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape)
sp_group = get_sequence_parallel_group()
# apply all_to_all to gather sequence and split attention heads
# [B, SUB_N, 3, NUM_HEAD, HEAD_DIM] -> [B, N, 3, NUM_HEAD_PER_DEVICE, HEAD_DIM]
qkv = all_to_all(qkv, sp_group, scatter_dim=3, gather_dim=1)
if self.enable_flashattn:
qkv_permute_shape = (2, 0, 1, 3, 4) # [3, B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM]
else:
qkv_permute_shape = (2, 0, 3, 1, 4) # [3, B, NUM_HEAD_PER_DEVICE, N, HEAD_DIM]
qkv = qkv.permute(qkv_permute_shape)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.enable_flashattn:
from flash_attn import flash_attn_func
x = flash_attn_func(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
)
else:
dtype = q.dtype
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.to(torch.float32)  # compute the softmax in float32 for numerical stability
attn = attn.softmax(dim=-1)
attn = attn.to(dtype) # cast back attn to original dtype
attn = self.attn_drop(attn)
x = attn @ v
if not self.enable_flashattn:
x = x.transpose(1, 2)
# apply all to all to gather back attention heads and split sequence
# [B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM] -> [B, SUB_N, NUM_HEAD, HEAD_DIM]
x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)
# reshape outputs back to [B, N, C]
x_output_shape = (B, N, C)
x = x.reshape(x_output_shape)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
super(MultiHeadCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# query: img tokens (x); key/value: condition tokens (cond); mask: list of per-sample condition lengths
B, N, C = x.shape
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
x = x.view(B, -1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
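# Minimal usage sketch (illustrative only; sizes are assumptions, and it needs a
# CUDA device with xformers installed): `cond` is packed as a single (1, sum(L_i), C)
# sequence and `mask` lists each sample's condition length, so the block-diagonal
# bias keeps every sample's queries attending only to its own condition tokens.
def _multi_head_cross_attention_example():
    xattn = MultiHeadCrossAttention(d_model=512, num_heads=8).cuda().half()
    x = torch.randn(2, 16, 512, device="cuda", dtype=torch.float16)        # img tokens
    cond = torch.randn(1, 5 + 7, 512, device="cuda", dtype=torch.float16)  # packed condition tokens
    out = xattn(x, cond, mask=[5, 7])  # sample 0 has 5 condition tokens, sample 1 has 7
    assert out.shape == (2, 16, 512)
    return out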
class MaskedMultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
super(MaskedMultiHeadCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# query: img tokens (x); key/value: condition tokens (cond); mask: (B, L) binary mask over condition tokens
B, S, C = x.shape
L = cond.shape[1]
q = self.q_linear(x).view(B, S, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(B, L, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
attn_bias = None
if mask is not None:
attn_bias = mask.unsqueeze(1).unsqueeze(1).repeat(1, self.num_heads, S, 1).to(q.dtype)  # (B, H, S, L)
# turn the binary mask into an additive bias: valid tokens get 0, padded tokens a large negative value
attn_bias[attn_bias == 0] = -1e9
attn_bias[attn_bias == 1] = 0
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
x = x.view(B, -1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MaskedMeanMultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
super(MaskedMeanMultiHeadCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# query: img tokens (x); key/value: condition tokens (cond); mask: (B, T, L) binary mask over condition tokens
B, T, S, C = x.shape
L = cond.shape[2]
x = rearrange(x, "B T S C -> B (T S) C")
N = x.shape[1]
cond = torch.mean(cond, dim=1) # B L C
mask = mask[:, 0, :] # B L
q = self.q_linear(x).view(B, N, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(B, L, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
attn_bias = None
if mask is not None:
attn_bias = mask.unsqueeze(1).unsqueeze(1).repeat(1, self.num_heads, N, 1).to(q.dtype)  # (B, H, N, L)
# turn the binary mask into an additive bias: valid tokens get 0, padded tokens a large negative value
attn_bias[attn_bias == 0] = -1e9
attn_bias[attn_bias == 1] = 0
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
x = rearrange(x, "B (T S) H C -> (B T) S (H C)", T=T, S=S)
x = self.proj(x)
x = self.proj_drop(x)
x = rearrange(x, "(B T) S C -> B T S C", B=B, T=T)
return x
class LongShortMultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
super(LongShortMultiHeadCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# query: img tokens (x); key/value: condition tokens (cond); mask: list of per-sample condition lengths
B, N, C = x.shape
M = cond.shape[1]
q = self.q_linear(x).view(B, N, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(B, M, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
x = x.view(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MultiHeadV2TCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
super(MultiHeadV2TCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# query: text tokens (x); key/value: img tokens (cond); mask: list of per-sample text token lengths
B, N, C = cond.shape
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [N] * B)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
x = x.view(B, -1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MultiHeadT2VCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
super(MultiHeadT2VCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# query: per-frame img tokens (x); key/value: condition tokens (cond); mask: condition token lengths, one per (sample, frame) block
B, T, N, C = x.shape
x = rearrange(x, 'B T N C -> (B T) N C')
q = self.q_linear(x)
q = rearrange(q, '(B T) N C -> B T N C', T=T)
q = q.view(1, -1, self.num_heads, self.head_dim) # 1(B T N) H C
kv = self.kv_linear(cond)
kv = kv.view(1, -1, 2, self.num_heads, self.head_dim) # 1 N 2 H C
k, v = kv.unbind(2)
attn_bias = None
if mask is not None:
#mask = [m for m in mask for _ in range(T)]
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * (B*T), mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
x = x.view(B, T, N, C)
x = rearrange(x, 'B T N C -> (B T) N C')
x = self.proj(x)
x = self.proj_drop(x)
x = rearrange(x, '(B T) N C -> B T N C', T=T)
return x
class FormerMultiHeadV2TCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
super(FormerMultiHeadV2TCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# x: text tokens; cond: per-frame img tokens; mask: text token lengths for the block-diagonal bias
_, N, C = x.shape # 1 N C
B, T, _, _ = cond.shape
cond = rearrange(cond, 'B T N C -> (B T) N C')
q = self.q_linear(x)
q = q.view(1, -1, self.num_heads, self.head_dim) # 1 N H C
kv = self.kv_linear(cond)
kv = rearrange(kv, '(B T) N C -> B T N C', B=B)
M = kv.shape[2] # M = H * W
former_frame_index = torch.arange(T) - 1
former_frame_index[0] = 0
former_kv = kv[:, former_frame_index]
former_kv = former_kv.view(1, -1, 2, self.num_heads, self.head_dim) # 1(B T N) 2 H C
former_k, former_v = former_kv.unbind(2)
attn_bias = None
if mask is not None:
#mask = [m for m in mask for _ in range(T)]
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [M] * (B*T))
x = xformers.ops.memory_efficient_attention(q, former_k, former_v, p=self.attn_drop.p, attn_bias=attn_bias)
x = x.view(1, -1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class LatterMultiHeadV2TCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
super(LatterMultiHeadV2TCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# x: text tokens; cond: per-frame img tokens; mask: text token lengths for the block-diagonal bias
_, N, C = x.shape # 1 N C
B, T, _, _ = cond.shape
cond = rearrange(cond, 'B T N C -> (B T) N C')
q = self.q_linear(x)
q = q.view(1, -1, self.num_heads, self.head_dim) # 1 N H C
kv = self.kv_linear(cond)
kv = rearrange(kv, '(B T) N C -> B T N C', T=T)
M = kv.shape[2] # M = H * W
latter_frame_index = torch.arange(T) + 1
latter_frame_index[-1] = T - 1
latter_kv = kv[:, latter_frame_index]
latter_kv = latter_kv.view(1, -1, 2, self.num_heads, self.head_dim) # 1(B T N) 2 H C
latter_k, latter_v = latter_kv.unbind(2)
attn_bias = None
if mask is not None:
# mask = [m for m in mask for _ in range(T)]
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [M] * (B*T))
x = xformers.ops.memory_efficient_attention(q, latter_k, latter_v, p=self.attn_drop.p, attn_bias=attn_bias)
x = x.view(1, -1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
def __init__(
self,
d_model,
num_heads,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__(d_model=d_model, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
def forward(self, x, cond, mask=None):
# query: img tokens (x); key/value: condition tokens (cond); mask: list of per-sample condition lengths
sp_group = get_sequence_parallel_group()
sp_size = dist.get_world_size(sp_group)
B, SUB_N, C = x.shape
N = SUB_N * sp_size
# shape:
# q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(B, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
# apply all_to_all to gather sequence and split attention heads
q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1)
k = split_forward_gather_backward(k, get_sequence_parallel_group(), dim=2, grad_scale="down")
v = split_forward_gather_backward(v, get_sequence_parallel_group(), dim=2, grad_scale="down")
q = q.view(1, -1, self.num_heads // sp_size, self.head_dim)
k = k.view(1, -1, self.num_heads // sp_size, self.head_dim)
v = v.view(1, -1, self.num_heads // sp_size, self.head_dim)
# compute attention
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# apply all to all to gather back attention heads and scatter sequence
x = x.view(B, -1, self.num_heads // sp_size, self.head_dim)
x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)
# apply output projection
x = x.view(B, -1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, num_patch, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final, x, shift, scale)
x = self.linear(x)
return x
class T2IFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, num_patch, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
self.out_channels = out_channels
def forward(self, x, t):
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
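# Minimal shape sketch (illustrative only; sizes are assumptions): `t` is the
# (B, hidden_size) timestep embedding; the learned scale_shift_table supplies the
# base shift/scale to which the timestep embedding is added.
def _t2i_final_layer_example():
    final = T2IFinalLayer(hidden_size=96, num_patch=4, out_channels=8)
    x = torch.randn(2, 16, 96)  # (B, N, hidden_size)
    t = torch.randn(2, 96)      # (B, hidden_size)
    out = final(x, t)
    assert out.shape == (2, 16, 4 * 8)
    return out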
# ==================
# Frequency Layers
# ==================
class SpatialFrequencyBlcok(nn.Module):
def __init__(self, dim):
super(SpatialFrequencyBlcok, self).__init__()
self.act_layer = nn.GELU(approximate="tanh")
# Process low-frequency
self.low_freq_layer1 = nn.Linear(in_features=dim, out_features=2 * dim)
self.low_freq_layer2 = nn.Linear(in_features=2 * dim, out_features=dim)
# Process high-frequency
self.high_freq_layer1 = nn.Linear(in_features=dim, out_features=2 * dim)
self.high_freq_layer2 = nn.Linear(in_features=2 * dim, out_features=dim)
def forward(self, x, use_cfg=True):
if use_cfg:
# x shape: torch.Size([4, 4096, 1152])
high_1, low_1, high_2, low_2 = torch.chunk(x, 4, dim=0)
highfreq = torch.cat((high_1, high_2), dim=0) # torch.Size([2, 4096, 1152])
lowfreq = torch.cat((low_1, low_2), dim=0) # torch.Size([2, 4096, 1152])
# extension
highfreq, hf_info = self.high_freq_layer1(highfreq).chunk(2, dim=-1)
lowfreq, lf_info = self.low_freq_layer1(lowfreq).chunk(2, dim=-1)
# fusion
high_1, high_2 = self.high_freq_layer2(torch.cat((highfreq, lf_info), dim=-1)).chunk(2, dim=0)
low_1, low_2 = self.low_freq_layer2(torch.cat((lowfreq, hf_info), dim=-1)).chunk(2, dim=0)
out = torch.cat((high_1, low_1, high_2, low_2), dim=0)
else:
highfreq, lowfreq = torch.chunk(x, 2, dim=0)
# extension
highfreq, hf_info = self.high_freq_layer1(highfreq).chunk(2, dim=-1)
lowfreq, lf_info = self.low_freq_layer1(lowfreq).chunk(2, dim=-1)
# fusion
highfreq = self.high_freq_layer2(torch.cat((highfreq, lf_info), dim=-1))
lowfreq = self.low_freq_layer2(torch.cat((lowfreq, hf_info), dim=-1))
out = torch.cat((highfreq, lowfreq), dim=0)
return out
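# Minimal shape sketch (illustrative only; sizes and the batch layout are
# assumptions inferred from the chunking above): with use_cfg=True the batch
# dimension is expected to hold [high_1, low_1, high_2, low_2], i.e. two cfg
# branches each split into high- and low-frequency parts.
def _spatial_frequency_block_example():
    block = SpatialFrequencyBlcok(dim=96)
    x = torch.randn(4, 16, 96)  # four equal chunks along the batch dimension
    out = block(x, use_cfg=True)
    assert out.shape == (4, 16, 96)
    return out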
class TemporalFrequencyBlock(nn.Module):
def __init__(self, dim, num_heads, qkv_bias, attn_drop=0.0, proj_drop=0.0):
super(TemporalFrequencyBlock, self).__init__()
assert dim % num_heads == 0, "dim must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(dim * 2, dim * 3, bias=qkv_bias)
# self.qkv2 = nn.Linear(dim, dim * 3, bias=qkv_bias)
# self.reduction = nn.Linear(dim * 6, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond):
# q, k, v are produced jointly from the concatenation of x and cond along the channel dimension
B, N, C = x.shape
# qkv1 = self.qkv1(x)
# qkv2 = self.qkv2(cond)
qkv = torch.cat((x, cond), dim=-1)
qkv = self.qkv(qkv)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv_permute_shape = (2, 0, 3, 1, 4)
qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
q, k, v = qkv.unbind(0)
dtype = q.dtype
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.to(torch.float32)  # compute the softmax in float32 for numerical stability
attn = attn.softmax(dim=-1)
attn = attn.to(dtype) # cast back attn to original dtype
attn = self.attn_drop(attn)
x = attn @ v
x_output_shape = (B, N, C)
x = x.transpose(1, 2)
x = x.reshape(x_output_shape)
x = self.proj(x)
x = self.proj_drop(x)
return x
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
class Encoder_3D(nn.Module):
"""
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
model) to encode image-space conditions ... into feature maps ..."
"""
def __init__(
self,
conditioning_embedding_channels: int,
# conditioning_channels: int = 3,
block_out_channels: Tuple[int] = (16, 32, 96, 256),
):
super().__init__()
# self.conv_in = nn.Conv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(nn.Conv3d(channel_in, channel_in, kernel_size=(3, 3, 3), padding=1, stride=1))
self.blocks.append(nn.Conv3d(channel_in, channel_out, kernel_size=(3, 3, 3), padding=1, stride=(1, 2, 2)))
self.conv_out = zero_module(
nn.Conv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=(3, 3, 3), padding=1, stride=1)
)
def forward(self, embedding):
# embedding = self.conv_in(conditioning)
# embedding = F.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
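# Minimal shape sketch (illustrative only; sizes are assumptions): the input must
# already have block_out_channels[0] channels (the conv_in above is commented out),
# and the three stride-(1, 2, 2) convolutions downsample H and W by 8x while keeping T.
def _encoder_3d_example():
    enc = Encoder_3D(conditioning_embedding_channels=4, block_out_channels=(16, 32, 96, 256))
    cond = torch.randn(1, 16, 8, 64, 64)  # (B, block_out_channels[0], T, H, W)
    feat = enc(cond)
    # conv_out is zero-initialized, so the initial output is all zeros with this shape
    assert feat.shape == (1, 4, 8, 8, 8)
    return feat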
# ===============================================
# Embedding Layers for Timesteps and Class Labels
# ===============================================
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
freqs = freqs.to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
if t_freq.dtype != dtype:
t_freq = t_freq.to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
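# Minimal usage sketch (illustrative only; sizes are assumptions): integer (or
# fractional) timesteps are mapped to sinusoidal features of size
# frequency_embedding_size and then projected by the two-layer MLP.
def _timestep_embedder_example():
    embedder = TimestepEmbedder(hidden_size=96)
    t = torch.randint(0, 1000, (4,))
    emb = embedder(t, dtype=torch.float32)
    assert emb.shape == (4, 96)
    return emb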
class LabelEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
return self.embedding_table(labels)
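# Minimal usage sketch (illustrative only; sizes are assumptions): with
# dropout_prob > 0 an extra "null" class index (num_classes) is reserved for
# classifier-free guidance; with train=False no labels are dropped.
def _label_embedder_example():
    embedder = LabelEmbedder(num_classes=1000, hidden_size=96, dropout_prob=0.1)
    labels = torch.randint(0, 1000, (4,))
    emb = embedder(labels, train=False)
    assert emb.shape == (4, 96)
    return emb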
class SizeEmbedder(TimestepEmbedder):
"""
Embeds scalar size information (e.g., image resolution or aspect ratio) into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
self.outdim = hidden_size
def forward(self, s, bs):
if s.ndim == 1:
s = s[:, None]
assert s.ndim == 2
if s.shape[0] != bs:
s = s.repeat(bs // s.shape[0], 1)
assert s.shape[0] == bs
b, dims = s.shape[0], s.shape[1]
s = rearrange(s, "b d -> (b d)")
s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
s_emb = self.mlp(s_freq)
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return s_emb
@property
def dtype(self):
return next(self.parameters()).dtype
class CaptionEmbedder(nn.Module):
"""
Embeds text caption features into vector representations. Also handles caption dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
super().__init__()
self.y_proj = Mlp(
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
)
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
self.uncond_prob = uncond_prob
def token_drop(self, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return caption
def forward(self, caption, train, force_drop_ids=None):
if train:
assert caption.shape[2:] == self.y_embedding.shape
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
caption = self.token_drop(caption, force_drop_ids)
caption = self.y_proj(caption)
return caption
# ===============================================
# Sine/Cosine Positional Embedding Functions
# ===============================================
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
"""
grid_size: int or (int, int) giving the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if not isinstance(grid_size, tuple):
grid_size = (grid_size, grid_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
if base_size is not None:
grid_h *= base_size / grid_size[0]
grid_w *= base_size / grid_size[1]
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
pos = np.arange(0, length)[..., None] / scale
return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
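# Minimal usage sketch (illustrative only; sizes are assumptions): a 16x16 token
# grid with a 768-dim embedding yields a (256, 768) table, with half of the
# channels encoding the row index and half the column index.
def _sincos_pos_embed_example():
    pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
    assert pos.shape == (16 * 16, 768)
    return pos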