from collections.abc import Iterable
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
assert isinstance(model, nn.Module)
def set_attr(module):
module.grad_checkpointing = True
module.fp32_attention = use_fp32_attention
module.grad_checkpointing_step = gc_step
model.apply(set_attr)
def auto_grad_checkpoint(module, *args, **kwargs):
if getattr(module, "grad_checkpointing", False):
if not isinstance(module, Iterable):
return checkpoint(module, *args, **kwargs)
gc_step = module[0].grad_checkpointing_step
return checkpoint_sequential(module, gc_step, *args, **kwargs)
return module(*args, **kwargs)
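# Illustrative usage (not part of the original module; `model` and its submodules are
# hypothetical): enable checkpointing once, then route forward calls through
# auto_grad_checkpoint so plain modules use `checkpoint` and module lists use
# `checkpoint_sequential` with the configured step.
#
#   set_grad_checkpoint(model, gc_step=2)
#   # inside model.forward:
#   #   x = auto_grad_checkpoint(self.blocks, x)        # ModuleList -> checkpoint_sequential
#   #   x = auto_grad_checkpoint(self.final_layer, x)   # single module -> checkpoint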
import torch
import torch.distributed as dist
# ====================
# All-To-All
# ====================
def _all_to_all(
input_: torch.Tensor,
world_size: int,
group: dist.ProcessGroup,
scatter_dim: int,
gather_dim: int,
):
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
class _AllToAll(torch.autograd.Function):
"""All-to-all communication.
Args:
input_: input matrix
process_group: communication group
scatter_dim: scatter dimension
gather_dim: gather dimension
"""
@staticmethod
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
ctx.world_size = dist.get_world_size(process_group)
output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)
return output
@staticmethod
def backward(ctx, grad_output):
grad_output = _all_to_all(
grad_output,
ctx.world_size,
ctx.process_group,
ctx.gather_dim,
ctx.scatter_dim,
)
return (
grad_output,
None,
None,
None,
)
def all_to_all(
input_: torch.Tensor,
process_group: dist.ProcessGroup,
scatter_dim: int = 2,
gather_dim: int = 1,
):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
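# Illustrative sketch (an assumption, not taken from this file): in sequence parallelism an
# activation sharded over the sequence dimension, e.g. [B, S/sp, H, D], can be re-sharded
# over the head dimension before attention and restored afterwards:
#
#   sp_group = get_sequence_parallel_group()                    # hypothetical group handle
#   q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1)    # -> [B, S, H/sp, D]
#   ...                                                         # attention over the full sequence
#   q = all_to_all(q, sp_group, scatter_dim=1, gather_dim=2)    # -> [B, S/sp, H, D]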
def _gather(
input_: torch.Tensor,
world_size: int,
group: dist.ProcessGroup,
gather_dim: int,
):
# all-gather the shards from every rank, then concatenate along gather_dim
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
dist.all_gather(gather_list, input_.contiguous(), group=group)
return torch.cat(gather_list, dim=gather_dim)
# ====================
# Gather-Split
# ====================
def _split(input_, pg: dist.ProcessGroup, dim=-1):
# skip if only one rank involved
world_size = dist.get_world_size(pg)
rank = dist.get_rank(pg)
if world_size == 1:
return input_
# Split along last dimension.
dim_size = input_.size(dim)
assert dim_size % world_size == 0, (
f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
f"cannot split tensor evenly"
)
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
output = tensor_list[rank].contiguous()
return output
def _gather(input_, pg: dist.ProcessGroup, dim=-1):
# skip if only one rank involved
input_ = input_.contiguous()
world_size = dist.get_world_size(pg)
dist.get_rank(pg)
if world_size == 1:
return input_
# all gather
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
assert input_.device.type == "cuda"
torch.distributed.all_gather(tensor_list, input_, group=pg)
# concat
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.
Args:
input_: input matrix.
process_group: parallel mode.
dim: dimension
"""
@staticmethod
def symbolic(graph, input_, process_group, dim, grad_scale):
return _gather(input_, process_group, dim)
@staticmethod
def forward(ctx, input_, process_group, dim, grad_scale):
ctx.mode = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
return _gather(input_, process_group, dim)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale == "up":
grad_output = grad_output * dist.get_world_size(ctx.mode)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.mode)
return _split(grad_output, ctx.mode, ctx.dim), None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the corresponding chuck to the rank.
Args:
input_: input matrix.
process_group: parallel mode.
dim: dimension
"""
@staticmethod
def symbolic(graph, input_, process_group, dim, grad_scale):
return _split(input_, process_group, dim)
@staticmethod
def forward(ctx, input_, process_group, dim, grad_scale):
ctx.mode = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
return _split(input_, process_group, dim)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale == "up":
grad_output = grad_output * dist.get_world_size(ctx.mode)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.mode)
return _gather(grad_output, ctx.mode, ctx.dim), None, None, None
def split_forward_gather_backward(input_, process_group, dim, grad_scale=1.0):
return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale)
def gather_forward_split_backward(input_, process_group, dim, grad_scale=None):
return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale)
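# Illustrative sketch (an assumption): shard hidden states of shape [B, S, C] along the
# sequence dimension inside a sequence-parallel block, then gather them back; "down"/"up"
# rescale the gradients by the world size in the backward pass.
#
#   x = split_forward_gather_backward(x, sp_group, dim=1, grad_scale="down")
#   ...  # per-shard computation
#   x = gather_forward_split_backward(x, sp_group, dim=1, grad_scale="up")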
import torch.distributed as dist
_GLOBAL_PARALLEL_GROUPS = dict()
def set_data_parallel_group(group: dist.ProcessGroup):
_GLOBAL_PARALLEL_GROUPS["data"] = group
def get_data_parallel_group():
return _GLOBAL_PARALLEL_GROUPS.get("data", None)
def set_sequence_parallel_group(group: dist.ProcessGroup):
_GLOBAL_PARALLEL_GROUPS["sequence"] = group
def get_sequence_parallel_group():
return _GLOBAL_PARALLEL_GROUPS.get("sequence", None)
import random
from typing import Optional
import numpy as np
import torch
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
DP_AXIS, SP_AXIS = 0, 1
class ZeroSeqParallelPlugin(LowLevelZeroPlugin):
def __init__(
self,
sp_size: int = 1,
stage: int = 2,
precision: str = "fp16",
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
reduce_bucket_size_in_m: int = 12,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
) -> None:
super().__init__(
stage=stage,
precision=precision,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type,
reduce_bucket_size_in_m=reduce_bucket_size_in_m,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
master_weights=master_weights,
verbose=verbose,
)
self.sp_size = sp_size
assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size"
self.dp_size = self.world_size // sp_size
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
self.dp_rank = self.pg_mesh.coordinate(DP_AXIS)
self.sp_rank = self.pg_mesh.coordinate(SP_AXIS)
def __del__(self):
"""Destroy the prcess groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups()
def prepare_dataloader(
self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
distributed_sampler_cls=None,
**kwargs,
):
_kwargs = kwargs.copy()
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
sampler = distributed_sampler_cls(dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
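# Illustrative usage (a minimal sketch; assumes colossalai.launch_from_torch has been called
# and `model`, `optimizer`, `dataset` exist -- the Booster calls below are the standard
# ColossalAI API, not defined in this file):
#
#   from colossalai.booster import Booster
#   plugin = ZeroSeqParallelPlugin(sp_size=2, stage=2, precision="bf16")
#   booster = Booster(plugin=plugin)
#   dataloader = plugin.prepare_dataloader(dataset, batch_size=4, shuffle=True)
#   model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)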
import torch
import torch.nn as nn
class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
# T5 uses a layer_norm which only scales and doesn't shift, also known as Root Mean Square
# Layer Normalization (https://arxiv.org/abs/1910.07467); the variance is therefore computed
# without subtracting the mean and there is no bias. Additionally we make sure that the
# accumulation for half-precision inputs is done in fp32.
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
@staticmethod
def from_native_module(module, *args, **kwargs):
assert module.__class__.__name__ == "FusedRMSNorm", (
"Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm."
"Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48"
)
layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps)
layer_norm.weight.data.copy_(module.weight.data)
layer_norm = layer_norm.to(module.weight.device)
return layer_norm
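# Illustrative (a sketch, assuming `fused_norm` is an apex FusedRMSNorm instance taken from
# a loaded T5 model): convert it back to the plain RMS norm defined above.
#
#   t5_norm = T5LayerNorm.from_native_module(fused_norm)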
from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription
class T5EncoderPolicy(Policy):
def config_sanity_check(self):
assert not self.shard_config.enable_tensor_parallelism
assert not self.shard_config.enable_flash_attention
def preprocess(self):
return self.model
def module_policy(self):
from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack
policy = {}
# check whether apex is installed
try:
from opensora.acceleration.shardformer.modeling.t5 import T5LayerNorm
# recover hf from fused rms norm to T5 norm which is faster
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=T5LayerNorm,
),
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=T5LayerNorm),
policy=policy,
target_key=T5LayerSelfAttention,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=T5LayerNorm),
policy=policy,
target_key=T5Stack,
)
except (ImportError, ModuleNotFoundError):
pass
# use jit operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_T5_layer_ff_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_method_replacement(
description={
"forward": get_T5_layer_self_attention_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerSelfAttention,
)
return policy
def postprocess(self):
return self.model
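# Illustrative usage with ColossalAI's ShardFormer (a sketch; `t5_encoder` is a loaded
# Hugging Face T5 encoder and the exact ShardConfig flags shown are assumptions):
#
#   from colossalai.shardformer import ShardConfig, ShardFormer
#   shard_config = ShardConfig(enable_tensor_parallelism=False, enable_jit_fused=True)
#   shard_former = ShardFormer(shard_config=shard_config)
#   t5_encoder, _ = shard_former.optimize(t5_encoder, policy=T5EncoderPolicy())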
from .datasets import VideoTextDataset
#from .datasets_celebv import DatasetFromCSV, get_transforms_image, get_transforms_video
#from .datasets_panda50m import DatasetFromCSV, get_transforms_image, get_transforms_video
#from .datasets_webvid10m import DatasetFromCSV, get_transforms_image, get_transforms_video
#from .datasets_ours1m import DatasetFromCSV, get_transforms_image, get_transforms_video
# from .datasets_ours1m1080p import DatasetFromCSV, get_transforms_image, get_transforms_video
#from .datasets_path2text import DatasetFromCSV, get_transforms_image, get_transforms_video
from .utils import prepare_dataloader, save_sample
import os
import random
import glob
import numpy as np
import torch
import torchvision
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from einops import rearrange
from torch.utils import data as data
from opensora.registry import DATASETS
from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, read_file, temporal_random_crop
IMG_FPS = 120
@DATASETS.register_module()
class VideoTextDataset(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
data_path,
num_frames=16,
frame_interval=1,
image_size=(256, 256),
transform_name="direct_crop",
):
self.data_path = data_path
self.data = read_file(data_path)
self.num_frames = num_frames
self.frame_interval = frame_interval
self.image_size = image_size
self.transforms = {
"image": get_transforms_image(transform_name, image_size),
"video": get_transforms_video(transform_name, image_size),
}
def _print_data_number(self):
num_videos = 0
num_images = 0
for path in self.data["path"]:
if self.get_type(path) == "video":
num_videos += 1
else:
num_images += 1
print(f"Dataset contains {num_videos} videos and {num_images} images.")
def get_type(self, path):
ext = os.path.splitext(path)[-1].lower()
if ext.lower() in VID_EXTENSIONS:
return "video"
else:
assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
return "image"
def getitem(self, index):
sample = self.data.iloc[index]
path = sample["path"]
text = sample["text"]
file_type = self.get_type(path)
if file_type == "video":
# loading
vframes, _, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
fps = info['video_fps']
# Sampling video frames
video = temporal_random_crop(vframes, self.num_frames, self.frame_interval)
# transform
transform = self.transforms["video"]
video = transform(video) # T C H W
else:
# loading
image = pil_loader(path)
fps = IMG_FPS  # images have no native frame rate; fall back to the default
# transform
transform = self.transforms["image"]
image = transform(image)
# repeat the single image to match the requested number of frames
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text, "fps": fps}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
path = self.data.iloc[index]["path"]
print(f"data {path}: {e}")
index = np.random.randint(len(self))
raise RuntimeError("Too many bad data.")
def __len__(self):
return len(self.data)
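# Illustrative usage (the annotation path is a placeholder):
#
#   dataset = VideoTextDataset("data/videos.csv", num_frames=16, frame_interval=3,
#                              image_size=(256, 256))
#   sample = dataset[0]   # {"video": C x T x H x W tensor, "text": caption, "fps": float}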
@DATASETS.register_module()
class VariableVideoTextDataset(VideoTextDataset):
def __init__(
self,
data_path,
num_frames=None,
frame_interval=1,
image_size=None,
transform_name=None,
):
super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None)
self.transform_name = transform_name
self.data["id"] = np.arange(len(self.data))
def get_data_info(self, index):
T = self.data.iloc[index]["num_frames"]
H = self.data.iloc[index]["height"]
W = self.data.iloc[index]["width"]
return T, H, W
def getitem(self, index):
# a hack to pass in the (time, height, width) info from sampler
index, num_frames, height, width = [int(val) for val in index.split("-")]
sample = self.data.iloc[index]
path = sample["path"]
text = sample["text"]
file_type = self.get_type(path)
ar = height / width
video_fps = 24 # default fps
if file_type == "video":
# loading
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
if "video_fps" in infos:
video_fps = infos["video_fps"]
# Sampling video frames
video = temporal_random_crop(vframes, num_frames, self.frame_interval)
# transform
transform = get_transforms_video(self.transform_name, (height, width))
video = transform(video) # T C H W
else:
# loading
image = pil_loader(path)
video_fps = IMG_FPS
# transform
transform = get_transforms_image(self.transform_name, (height, width))
image = transform(image)
# repeat
video = image.unsqueeze(0)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {
"video": video,
"text": text,
"num_frames": num_frames,
"height": height,
"width": width,
"ar": ar,
"fps": video_fps,
}
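# Illustrative (a sketch): the sampler is expected to encode the target shape into the
# string index as "<row>-<num_frames>-<height>-<width>", e.g.
#
#   item = dataset.getitem("42-16-256-256")   # 16 frames at 256x256 from row 42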
@DATASETS.register_module()
class PairedCaptionDataset(data.Dataset):
def __init__(
self,
root_folder=None,
null_text_ratio=0.5,
):
super(PairedCaptionDataset, self).__init__()
self.null_text_ratio = null_text_ratio
self.lr_list = []
self.gt_list = []
self.tag_path_list = []
# root_folders = root_folders.split(',')
# for root_folder in root_folders:
lr_path = os.path.join(root_folder, 'lq')
tag_path = os.path.join(root_folder, 'text')
gt_path = os.path.join(root_folder, 'gt')
# sort so that the lq/gt/text files line up by index
self.lr_list += sorted(glob.glob(os.path.join(lr_path, '*.mp4')))
self.gt_list += sorted(glob.glob(os.path.join(gt_path, '*.mp4')))
self.tag_path_list += sorted(glob.glob(os.path.join(tag_path, '*.txt')))
assert len(self.lr_list) == len(self.gt_list)
assert len(self.lr_list) == len(self.tag_path_list)
def __getitem__(self, index):
gt_path = self.gt_list[index]
vframes_gt, _, _ = torchvision.io.read_video(filename=gt_path, pts_unit="sec", output_format="TCHW")
vframes_gt = (rearrange(vframes_gt, "T C H W -> C T H W") / 255) * 2 - 1
# gt = self.trandform(vframes_gt)
lq_path = self.lr_list[index]
vframes_lq, _, _ = torchvision.io.read_video(filename=lq_path, pts_unit="sec", output_format="TCHW")
vframes_lq = (rearrange(vframes_lq, "T C H W -> C T H W") / 255) * 2 - 1
# lq = self.trandform(vframes_lq)
if random.random() < self.null_text_ratio:
tag = ''
else:
tag_path = self.tag_path_list[index]
with open(tag_path, 'r') as file:
tag = file.read()
return {"gt": vframes_gt, "lq": vframes_lq, "text": tag}
def __len__(self):
return len(self.gt_list)
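# Illustrative usage (a sketch; the root folder is a placeholder and is assumed to contain
# matching files under <root>/gt, <root>/lq and <root>/text):
#
#   dataset = PairedCaptionDataset(root_folder="data/paired", null_text_ratio=0.5)
#   item = dataset[0]   # {"gt": C x T x H x W in [-1, 1], "lq": same shape, "text": str}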
import csv
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import center_crop_arr
# import video_transforms
# from utils import center_crop_arr
import json
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
# open-sora-plan+magictime dataset
class DatasetFromCSV(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
csv_path,
num_frames=16,
frame_interval=1,
transform=None,
root=None,
):
video_samples = []
with open(csv_path, "r") as f:
reader = csv.reader(f)
csv_list = list(reader)
for v_s in csv_list[1:]: # no csv head
vid_path = v_s[0]
vid_name = vid_path.split('/')[-1]
vid_path = os.path.join(root, vid_name)
vid_caption = v_s[1]
if os.path.exists(vid_path):
video_samples.append([vid_path, vid_caption])
self.samples = video_samples # 35666 vids
self.is_video = True
self.transform = transform
self.num_frames = num_frames
self.frame_interval = frame_interval
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
self.root = root
def getitem(self, index):
sample = self.samples[index]
path = sample[0]
text = sample[1]
if self.is_video:
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
loop_index = index
while(total_frames < self.num_frames or is_exit == False):
#print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exit!!!")
loop_index += 1
if loop_index >= len(self.samples):
loop_index = 0
sample = self.samples[loop_index]
path = sample[0]
text = sample[1]
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
# video exists and total_frames >= self.num_frames
# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= self.num_frames
), f"{path} with index {index} has not enough frames."
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
#print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
video = vframes[frame_indice]
video = self.transform(video) # T C H W
else:
image = pil_loader(path)
image = self.transform(image)
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many bad data.")
def __len__(self):
return len(self.samples)
if __name__ == '__main__':
data_path = '/mnt/bn/yh-volume0/dataset/CelebvHQ/CelebvHQ_caption_llava-34B.csv'
root='/mnt/bn/yh-volume0/dataset/CelebvHQ/35666'
dataset = DatasetFromCSV(
data_path,
transform=get_transforms_video(),
num_frames=16,
frame_interval=3,
root=root,
)
sampler = DistributedSampler(
dataset,
num_replicas=1,
rank=0,
shuffle=True,
seed=1
)
loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=sampler,
num_workers=0,
pin_memory=True,
drop_last=True
)
for video_data in loader:
print(video_data)
import csv
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import center_crop_arr
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
class DatasetFromCSV(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
csv_path,
num_frames=16,
frame_interval=1,
transform=None,
root=None,
):
self.csv_path = csv_path
with open(csv_path, "r") as f:
reader = csv.reader(f)
self.samples = list(reader)
ext = self.samples[0][0].split(".")[-1]
if ext.lower() in ("mp4", "avi", "mov", "mkv"):
self.is_video = True
else:
assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
self.is_video = False
self.transform = transform
self.num_frames = num_frames
self.frame_interval = frame_interval
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
self.root = root
def getitem(self, index):
sample = self.samples[index]
path = sample[0]
if self.root:
path = os.path.join(self.root, path)
text = sample[1]
if self.is_video:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= self.num_frames
), f"{path} with index {index} has not enough frames."
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
video = vframes[frame_indice]
video = self.transform(video) # T C H W
else:
image = pil_loader(path)
image = self.transform(image)
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many bad data.")
def __len__(self):
return len(self.samples)
import csv
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import center_crop_arr
# import video_transforms
# from utils import center_crop_arr
import json
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
# open-sora-plan+magictime dataset
class DatasetFromCSV(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
csv_path_magictime="/mnt/bn/videodataset-uswest/MagicTime/caption/ChronoMagic_train.csv",
osp_path="/mnt/bn/videodataset-uswest/open_sora_dataset/raw/caption/sharegpt4v_path_cap_64x512x512.json", # for open sora plan
celebvhq_path="/mnt/bn/videodataset-uswest/CelebvHQ/CelebvHQ_caption_llava-34B_2k.csv", # for celebvhq
panda60w_path = "/mnt/bn/videodataset-uswest/VDiT/code/Open-Sora_caption/video_caption.csv", # for panda0.6m
num_frames=16,
frame_interval=1,
transform=None,
root_magictime="/mnt/bn/videodataset-uswest/MagicTime/video",
osp_root="/mnt/bn/videodataset-uswest/open_sora_dataset/raw/videos", # for open sora plan
celebvhq_root="/mnt/bn/videodataset-uswest/CelebvHQ/35666", # for celebvhq
panda60w_root = "/mnt/bn/videodataset-uswest/VDiT/dataset/panda-ours", # for panda0.6m
):
video_samples = []
with open(csv_path_magictime, "r") as f:
reader = csv.reader(f)
csv_list = list(reader)
for v_s in csv_list[1:]: # no csv head
vid_name = v_s[0]
vid_path = os.path.join(root_magictime, vid_name+".mp4")
vid_caption = v_s[1]
if os.path.exists(vid_path):
video_samples.append([vid_path, vid_caption])
print("magictime samples:", len(video_samples))
# magictime 2255
with open(osp_path, 'r', encoding='utf-8') as file:
extra_data = json.load(file)
for v_s in extra_data:
vid_name = v_s["path"].split('data_split_tt')[1]
vid_name = vid_name.replace(' ', '_')
vid_path = osp_root + vid_name
vid_caption = v_s["cap"]
if len(vid_caption) != 0 and os.path.exists(vid_path):
video_samples.append([vid_path, vid_caption[0]])
print("open-sora-plan+magictime samples:", len(video_samples))
# open-sora-plan 423585 -> 423567
with open(celebvhq_path, "r") as f:
reader = csv.reader(f)
csv_list = list(reader)
for v_s in csv_list[1:]: # no csv head
vid_path = v_s[0]
vid_name = vid_path.split('/')[-1]
vid_path = os.path.join(celebvhq_root, vid_name)
vid_caption = v_s[1]
if os.path.exists(vid_path):
video_samples.append([vid_path, vid_caption])
print("open-sora-plan+magictime+celevb samples:", len(video_samples))
# celevb 35596
with open(panda60w_path, "r") as f:
reader = csv.reader(f)
csv_list = list(reader)
for v_s in csv_list[1:]: # no csv head
vid_path = v_s[0]
vid_caption = v_s[1]
if os.path.exists(vid_path):
video_samples.append([vid_path, vid_caption])
print("open-sora-plan+magictime+celevb+panda0.6m samples:", len(video_samples))
# panda0.6m
self.samples = video_samples #
self.is_video = True
self.transform = transform
self.num_frames = num_frames
self.frame_interval = frame_interval
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
#self.root = root
def getitem(self, index):
sample = self.samples[index]
path = sample[0]
text = sample[1]
if self.is_video:
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
loop_index = index
while(total_frames < self.num_frames or is_exit == False):
#print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exit!!!")
loop_index += 1
if loop_index >= len(self.samples):
loop_index = 0
sample = self.samples[loop_index]
path = sample[0]
text = sample[1]
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
# video exists and total_frames >= self.num_frames
# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= self.num_frames
), f"{path} with index {index} has not enough frames."
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
#print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
video = vframes[frame_indice]
video = self.transform(video) # T C H W
else:
image = pil_loader(path)
image = self.transform(image)
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many bad data.")
def __len__(self):
return len(self.samples)
if __name__ == '__main__':
dataset = DatasetFromCSV(
transform=get_transforms_video(),
num_frames=16,
frame_interval=3,
)
sampler = DistributedSampler(
dataset,
num_replicas=1,
rank=0,
shuffle=True,
seed=1
)
loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=sampler,
num_workers=0,
pin_memory=True,
drop_last=True
)
for video_data in loader:
print(video_data)
import csv
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import center_crop_arr
# import video_transforms
# from utils import center_crop_arr
import json
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
# import ipdb
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
# open-sora-plan+magictime dataset
class DatasetFromCSV(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
csv_path,
num_frames=16,
frame_interval=1,
transform=None,
root=None,
):
video_samples = []
with open(csv_path, "r") as f:
reader = csv.reader(f)
csv_list = list(reader)
for v_s in csv_list[1:]: # no csv head
vid_path = v_s[0]
vid_caption = v_s[1]
if os.path.exists(vid_path):
video_samples.append([vid_path, vid_caption])
self.samples = video_samples
self.is_video = True
self.transform = transform
self.num_frames = num_frames
self.frame_interval = frame_interval
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
self.root = root
def getitem(self, index):
sample = self.samples[index]
path = sample[0]
text = sample[1]
if self.is_video:
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
loop_index = index
while(total_frames < self.num_frames or is_exit == False):
#print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exit!!!")
loop_index += 1
if loop_index >= len(self.samples):
loop_index = 0
sample = self.samples[loop_index]
path = sample[0]
text = sample[1]
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
# video exists and total_frames >= self.num_frames
# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= self.num_frames
), f"{path} with index {index} has not enough frames."
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
#print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
video = vframes[frame_indice]
video = self.transform(video) # T C H W
else:
image = pil_loader(path)
image = self.transform(image)
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many bad data.")
def __len__(self):
return len(self.samples)
if __name__ == '__main__':
data_path = '/mnt/bn/yh-volume0/dataset/CelebvHQ/CelebvHQ_caption_llava-34B.csv'
root='/mnt/bn/yh-volume0/dataset/CelebvHQ/35666'
dataset = DatasetFromCSV(
data_path,
transform=get_transforms_video(),
num_frames=16,
frame_interval=3,
root=root,
)
sampler = DistributedSampler(
dataset,
num_replicas=1,
rank=0,
shuffle=True,
seed=1
)
loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=sampler,
num_workers=0,
pin_memory=True,
drop_last=True
)
for video_data in loader:
print(video_data)