Commit 78bae405 authored by mashun1's avatar mashun1
Browse files

open_sora_inference

parents
Pipeline #826 canceled with stages
import torch.distributed as dist
_GLOBAL_PARALLEL_GROUPS = dict()
def set_data_parallel_group(group: dist.ProcessGroup):
_GLOBAL_PARALLEL_GROUPS["data"] = group
def get_data_parallel_group():
return _GLOBAL_PARALLEL_GROUPS.get("data", None)
def set_sequence_parallel_group(group: dist.ProcessGroup):
_GLOBAL_PARALLEL_GROUPS["sequence"] = group
def get_sequence_parallel_group():
return _GLOBAL_PARALLEL_GROUPS.get("sequence", None)
import random
from typing import Optional
import numpy as np
import torch
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
DP_AXIS, SP_AXIS = 0, 1
class ZeroSeqParallelPlugin(LowLevelZeroPlugin):
def __init__(
self,
sp_size: int = 1,
stage: int = 2,
precision: str = "fp16",
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
reduce_bucket_size_in_m: int = 12,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
) -> None:
super().__init__(
stage=stage,
precision=precision,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type,
reduce_bucket_size_in_m=reduce_bucket_size_in_m,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
master_weights=master_weights,
verbose=verbose,
)
self.sp_size = sp_size
assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size"
self.dp_size = self.world_size // sp_size
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
self.dp_rank = self.pg_mesh.coordinate(DP_AXIS)
self.sp_rank = self.pg_mesh.coordinate(SP_AXIS)
def __del__(self):
"""Destroy the prcess groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups()
def prepare_dataloader(
self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
distributed_sampler_cls=None,
**kwargs,
):
_kwargs = kwargs.copy()
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
sampler = distributed_sampler_cls(dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
import torch
import torch.nn as nn
class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
# half-precision inputs is done in fp32
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
@staticmethod
def from_native_module(module, *args, **kwargs):
assert module.__class__.__name__ == "FusedRMSNorm", (
"Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm."
"Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48"
)
layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps)
layer_norm.weight.data.copy_(module.weight.data)
layer_norm = layer_norm.to(module.weight.device)
return layer_norm
from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription
class T5EncoderPolicy(Policy):
def config_sanity_check(self):
assert not self.shard_config.enable_tensor_parallelism
assert not self.shard_config.enable_flash_attention
def preprocess(self):
return self.model
def module_policy(self):
from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack
policy = {}
# check whether apex is installed
try:
from opensora.acceleration.shardformer.modeling.t5 import T5LayerNorm
# recover hf from fused rms norm to T5 norm which is faster
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=T5LayerNorm,
),
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=T5LayerNorm),
policy=policy,
target_key=T5LayerSelfAttention,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=T5LayerNorm),
policy=policy,
target_key=T5Stack,
)
except (ImportError, ModuleNotFoundError):
pass
# use jit operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_T5_layer_ff_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_method_replacement(
description={
"forward": get_T5_layer_self_attention_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerSelfAttention,
)
return policy
def postprocess(self):
return self.model
from .datasets import DatasetFromCSV, get_transforms_image, get_transforms_video
from .utils import prepare_dataloader, save_sample
import csv
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import center_crop_arr
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
class DatasetFromCSV(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
csv_path,
num_frames=16,
frame_interval=1,
transform=None,
root=None,
):
self.csv_path = csv_path
with open(csv_path, "r") as f:
reader = csv.reader(f)
self.samples = list(reader)
ext = self.samples[0][0].split(".")[-1]
if ext.lower() in ("mp4", "avi", "mov", "mkv"):
self.is_video = True
else:
assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
self.is_video = False
self.transform = transform
self.num_frames = num_frames
self.frame_interval = frame_interval
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
self.root = root
def getitem(self, index):
sample = self.samples[index]
path = sample[0]
if self.root:
path = os.path.join(self.root, path)
text = sample[1]
if self.is_video:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= self.num_frames
), f"{path} with index {index} has not enough frames."
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
video = vframes[frame_indice]
video = self.transform(video) # T C H W
else:
image = pil_loader(path)
image = self.transform(image)
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many bad data.")
def __len__(self):
return len(self.samples)
import random
from typing import Iterator, Optional
import numpy as np
import torch
from PIL import Image
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision.io import write_video
from torchvision.utils import save_image
def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)):
"""
Args:
x (Tensor): shape [C, T, H, W]
"""
assert x.ndim == 4
if x.shape[1] == 1: # T = 1: save as image
save_path += ".png"
x = x.squeeze(1)
save_image([x], save_path, normalize=normalize, value_range=value_range)
else:
save_path += ".mp4"
if normalize:
low, high = value_range
x.clamp_(min=low, max=high)
x.sub_(low).div_(max(high - low, 1e-5))
x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
write_video(save_path, x, fps=fps, video_codec="h264")
print(f"Saved to {save_path}")
class StatefulDistributedSampler(DistributedSampler):
def __init__(
self,
dataset: Dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
self.start_index: int = 0
def __iter__(self) -> Iterator:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
def prepare_dataloader(
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
process_group: Optional[ProcessGroup] = None,
**kwargs,
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(
dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
# Copyright 2024 Vchitect/Latte
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.# Modified from Latte
# - This file is adapted from https://github.com/Vchitect/Latte/blob/main/datasets/video_transforms.py
import numbers
import random
import torch
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i : i + h, j : j + w]
def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
def resize_scale(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size[0] / min(H, W)
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
if h < th or w < tw:
raise ValueError("height and width must be no smaller than crop_size")
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
def center_crop_using_short_edge(clip):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h < w:
th, tw = h, h
i = 0
j = int(round((w - tw) / 2.0))
else:
th, tw = w, w
i = int(round((h - th) / 2.0))
j = 0
return crop(clip, i, j, th, tw)
def random_shift_crop(clip):
"""
Slide along the long edge, with the short edge as crop size
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h <= w:
short_edge = h
else:
short_edge = w
th, tw = short_edge, short_edge
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return crop(clip, i, j, th, tw)
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
# print(mean)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
Returns:
flipped clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
return clip.flip(-1)
class RandomCropVideo:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: randomly cropped video clip.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = self.size
if h < th or w < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class CenterCropResizeVideo:
"""
First use the short side for cropping length,
center crop video, then resize to the specified size
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop_using_short_edge(clip)
clip_center_crop_resize = resize(
clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode
)
return clip_center_crop_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class UCFCenterCropVideo:
"""
First scale to the specified size in equal proportion to the short edge,
then center cropping
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
clip_center_crop = center_crop(clip_resize, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class KineticsRandomCropResizeVideo:
"""
Slide along the long edge, with the short edge as crop size. And resie to the desired size.
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
clip_random_crop = random_shift_crop(clip)
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
return clip_resize
class CenterCropVideo:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop(clip, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class NormalizeVideo:
"""
Normalize the video clip by mean subtraction and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
return to_tensor(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class RandomHorizontalFlipVideo:
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (T, C, H, W)
Return:
clip (torch.tensor): Size is (T, C, H, W)
"""
if random.random() < self.p:
clip = hflip(clip)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})"
# ------------------------------------------------------------
# --------------------- Sampling ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
"""Temporally crop the given frame indices at a random location.
Args:
size (int): Desired length of frames will be seen in the model.
"""
def __init__(self, size):
self.size = size
def __call__(self, total_frames):
rand_end = max(0, total_frames - self.size - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + self.size, total_frames)
return begin_index, end_index
if __name__ == "__main__":
import os
import numpy as np
import torchvision.io as io
from torchvision import transforms
from torchvision.utils import save_image
vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW")
trans = transforms.Compose(
[
ToTensorVideo(),
RandomHorizontalFlipVideo(),
UCFCenterCropVideo(512),
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
target_video_len = 32
frame_interval = 1
total_frames = len(vframes)
print(total_frames)
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
# Sampling video frames
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
# print(start_frame_ind)
# print(end_frame_ind)
assert end_frame_ind - start_frame_ind >= target_video_len
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
print(frame_indice)
select_vframes = vframes[frame_indice]
print(select_vframes.shape)
print(select_vframes.dtype)
select_vframes_trans = trans(select_vframes)
print(select_vframes_trans.shape)
print(select_vframes_trans.dtype)
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
print(select_vframes_trans_int.dtype)
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
for i in range(target_video_len):
save_image(
select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1)
)
from .dit import *
from .latte import *
from .pixart import *
from .stdit import *
from .text_encoder import *
from .vae import *
from .dit import DiT, DiT_XL_2, DiT_XL_2x2
# Modified from Meta DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange
from timm.models.vision_transformer import Mlp
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
FinalLayer,
LabelEmbedder,
PatchEmbed3D,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
class DiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(
self,
hidden_size,
num_heads,
mlp_ratio=4.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.enable_flashattn = enable_flashattn
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.attn = Attention(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=enable_flashattn,
)
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1, x, shift_msa, scale_msa))
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2, x, shift_mlp, scale_mlp))
return x
@MODELS.register_module()
class DiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=(16, 32, 32),
in_channels=4,
patch_size=(1, 2, 2),
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
condition="text",
no_temporal_pos_emb=False,
caption_channels=512,
model_max_length=77,
dtype=torch.float32,
enable_flashattn=False,
enable_layernorm_kernel=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.input_size = input_size
num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
self.num_patches = num_patches
self.num_temporal = input_size[0] // patch_size[0]
self.num_spatial = num_patches // self.num_temporal
self.num_heads = num_heads
self.dtype = dtype
self.use_text_encoder = not condition.startswith("label")
if enable_flashattn:
assert dtype in [
torch.float16,
torch.bfloat16,
], f"Flash attention only supports float16 and bfloat16, but got {self.dtype}"
self.no_temporal_pos_emb = no_temporal_pos_emb
self.mlp_ratio = mlp_ratio
self.depth = depth
self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed())
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
self.x_embedder = PatchEmbed3D(patch_size, in_channels, embed_dim=hidden_size)
if not self.use_text_encoder:
num_classes = int(condition.split("_")[-1])
self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
else:
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
uncond_prob=class_dropout_prob,
act_layer=approx_gelu,
token_num=1, # pooled token
)
self.t_embedder = TimestepEmbedder(hidden_size)
self.blocks = nn.ModuleList(
[
DiTBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
enable_flashattn=enable_flashattn,
enable_layernorm_kernel=enable_layernorm_kernel,
)
for _ in range(depth)
]
)
self.final_layer = FinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
self.initialize_weights()
self.enable_flashattn = enable_flashattn
self.enable_layernorm_kernel = enable_layernorm_kernel
def get_spatial_pos_embed(self):
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
self.input_size[1] // self.patch_size[1],
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def unpatchify(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def forward(self, x, t, y):
"""
Forward pass of DiT.
x: (B, C, T, H, W) tensor of inputs
t: (B,) tensor of diffusion timesteps
y: list of text
"""
# origin inputs should be float32, cast to specified dtype
x = x.to(self.dtype)
# embedding
x = self.x_embedder(x) # (B, N, D)
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
x = x + self.pos_embed_spatial
if not self.no_temporal_pos_emb:
x = rearrange(x, "b t s d -> b s t d")
x = x + self.pos_embed_temporal
x = rearrange(x, "b s t d -> b (t s) d")
else:
x = rearrange(x, "b t s d -> b (t s) d")
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
if self.use_text_encoder:
y = y.squeeze(1).squeeze(1)
condition = t + y
# blocks
for _, block in enumerate(self.blocks):
c = condition
x = auto_grad_checkpoint(block, x, c) # (B, N, D)
# final process
x = self.final_layer(x, condition) # (B, N, num_patches * out_channels)
x = self.unpatchify(x) # (B, out_channels, T, H, W)
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
if module.weight.requires_grad_:
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
# Zero-out text embedding layers:
if self.use_text_encoder:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
@MODELS.register_module("DiT-XL/2")
def DiT_XL_2(from_pretrained=None, **kwargs):
model = DiT(
depth=28,
hidden_size=1152,
patch_size=(1, 2, 2),
num_heads=16,
**kwargs,
)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
@MODELS.register_module("DiT-XL/2x2")
def DiT_XL_2x2(from_pretrained=None, **kwargs):
model = DiT(
depth=28,
hidden_size=1152,
patch_size=(2, 2, 2),
num_heads=16,
**kwargs,
)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
from .latte import Latte, Latte_XL_2, Latte_XL_2x2
# Copyright 2024 Vchitect/Latte
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.# Modified from Latte
#
#
# This file is mofied from https://github.com/Vchitect/Latte/blob/main/models/latte.py
#
# With references to:
# Latte: https://github.com/Vchitect/Latte
# DiT: https://github.com/facebookresearch/DiT/tree/main
import torch
from einops import rearrange, repeat
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.dit import DiT
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
@MODELS.register_module()
class Latte(DiT):
def forward(self, x, t, y):
"""
Forward pass of DiT.
x: (B, C, T, H, W) tensor of inputs
t: (B,) tensor of diffusion timesteps
y: list of text
"""
# origin inputs should be float32, cast to specified dtype
x = x.to(self.dtype)
# embedding
x = self.x_embedder(x) # (B, N, D)
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
x = x + self.pos_embed_spatial
x = rearrange(x, "b t s d -> b (t s) d")
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
if self.use_text_encoder:
y = y.squeeze(1).squeeze(1)
condition = t + y
condition_spatial = repeat(condition, "b d -> (b t) d", t=self.num_temporal)
condition_temporal = repeat(condition, "b d -> (b s) d", s=self.num_spatial)
# blocks
for i, block in enumerate(self.blocks):
if i % 2 == 0:
# spatial
x = rearrange(x, "b (t s) d -> (b t) s d", t=self.num_temporal, s=self.num_spatial)
c = condition_spatial
else:
# temporal
x = rearrange(x, "b (t s) d -> (b s) t d", t=self.num_temporal, s=self.num_spatial)
c = condition_temporal
if i == 1:
x = x + self.pos_embed_temporal
x = auto_grad_checkpoint(block, x, c) # (B, N, D)
if i % 2 == 0:
x = rearrange(x, "(b t) s d -> b (t s) d", t=self.num_temporal, s=self.num_spatial)
else:
x = rearrange(x, "(b s) t d -> b (t s) d", t=self.num_temporal, s=self.num_spatial)
# final process
x = self.final_layer(x, condition) # (B, N, num_patches * out_channels)
x = self.unpatchify(x) # (B, out_channels, T, H, W)
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
@MODELS.register_module("Latte-XL/2")
def Latte_XL_2(from_pretrained=None, **kwargs):
model = Latte(
depth=28,
hidden_size=1152,
patch_size=(1, 2, 2),
num_heads=16,
**kwargs,
)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
@MODELS.register_module("Latte-XL/2x2")
def Latte_XL_2x2(from_pretrained=None, **kwargs):
model = Latte(
depth=28,
hidden_size=1152,
patch_size=(2, 2, 2),
num_heads=16,
**kwargs,
)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
This diff is collapsed.
from .pixart import PixArt, PixArt_XL_2
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment