import os
import random
import re
from typing import Iterator, Optional

from torch.distributed import ProcessGroup
import numpy as np
import pandas as pd
import requests
import torch
import cv2
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data.distributed import DistributedSampler
from torchvision.io import write_video
from torchvision.utils import save_image

from . import video_transforms
from .wavelet_color_fix import adain_color_fix

VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

# URL matcher used by `is_url`: http(s)/ftp(s) scheme, then a domain name,
# "localhost" or a dotted IPv4 address, an optional port and an optional path.
regex = re.compile(
    r"^(?:http|ftp)s?://"  # http:// or https://
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|"  # domain...
    r"localhost|"  # localhost...
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"  # ...or ip
    r"(?::\d+)?"  # optional port
    r"(?:/?|[/?]\S+)$",
    re.IGNORECASE,
)


def is_url(url):
    """Return True if ``url`` matches the module-level URL pattern ``regex``."""
    return regex.match(url) is not None


def read_file(input_path):
    """Load a ``.csv`` or ``.parquet`` table into a pandas DataFrame.

    Raises:
        NotImplementedError: for any other file extension.
    """
    if input_path.endswith(".csv"):
        return pd.read_csv(input_path)
    elif input_path.endswith(".parquet"):
        return pd.read_parquet(input_path)
    else:
        raise NotImplementedError(f"Unsupported file format: {input_path}")


def download_url(input_path):
    """Download ``input_path`` into the local ``cache`` directory.

    The file is saved as ``cache/<basename of the URL>``; an existing file with
    that name is overwritten.

    Returns:
        str: path of the downloaded file.

    Raises:
        requests.HTTPError: if the server answers with an error status. Without
            this check the error page body would be cached as if it were the file.
    """
    output_dir = "cache"
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.basename(input_path)
    output_path = os.path.join(output_dir, base_name)
    response = requests.get(input_path)
    response.raise_for_status()
    with open(output_path, "wb") as handler:
        handler.write(response.content)
    print(f"URL {input_path} downloaded to {output_path}")
    return output_path


def temporal_random_crop(vframes, num_frames, frame_interval):
    """Sample ``num_frames`` evenly spaced frames from a random temporal window.

    ``video_transforms.TemporalRandomCrop(num_frames * frame_interval)`` picks the
    ``[start, end)`` window. NOTE(review): the window-length semantics come from
    that helper and are assumed, not visible here. ``num_frames`` indices are then
    spread uniformly over the window.
    """
    temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
    total_frames = len(vframes)
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    assert end_frame_ind - start_frame_ind >= num_frames
    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
    video = vframes[frame_indice]
    return video


def compute_bidirectional_optical_flow(video_frames):
    """Compute Farneback optical flow between every pair of consecutive frames.

    Args:
        video_frames (torch.Tensor): RGB video of shape [T, C, H, W]. It is
            converted with ``.numpy()``, so it must live on the CPU.
            NOTE(review): ``cv2.calcOpticalFlowFarneback`` needs 8-bit input, so
            a uint8 tensor is presumably expected here — confirm with callers.

    Returns:
        torch.Tensor: float tensor of shape [2, max(T - 1, 0), H, W].
            ``[:, t]`` holds the sum of the forward (t -> t+1) and backward
            (t+1 -> t) flow fields. Index 0 is the x component and index 1 is
            the y component, following OpenCV's flow layout.
    """
    frames = video_frames.permute(0, 2, 3, 1).contiguous().numpy()  # T C H W -> T H W C
    T, H, W, _ = frames.shape
    bidirectional_flow = torch.zeros((2, max(T - 1, 0), H, W))
    for t in range(T - 1):
        prev_frame = cv2.cvtColor(frames[t], cv2.COLOR_RGB2GRAY)
        next_frame = cv2.cvtColor(frames[t + 1], cv2.COLOR_RGB2GRAY)
        # Forward optical flow: frame t -> frame t+1.
        flow_forward = cv2.calcOpticalFlowFarneback(prev_frame, next_frame, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        # Backward optical flow: frame t+1 -> frame t.
        flow_backward = cv2.calcOpticalFlowFarneback(next_frame, prev_frame, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        # Combine the forward and backward flow maps. OpenCV returns [H, W, 2], so
        # the component axis is moved to the front with a transpose. A reshape to
        # (2, H, W) would interleave the x/y values and scramble the field.
        # NOTE(review): for consistent motion the two flows roughly cancel, so the
        # sum acts like a forward/backward consistency residual — confirm intent.
        combined = (flow_forward + flow_backward).transpose(2, 0, 1)
        bidirectional_flow[:, t] = torch.from_numpy(np.ascontiguousarray(combined))
    return bidirectional_flow


def blur_video(video, kernel_size=(21, 21), sigma=21):
    """Apply a Gaussian blur to every frame of a video.

    Args:
        video (torch.Tensor): input video of shape [T, C, H, W]. Each frame is
            converted with ``.numpy()``, so the tensor must be on the CPU.
        kernel_size (tuple): Gaussian kernel size passed to ``cv2.GaussianBlur``.
            OpenCV requires positive odd values (or zero). Defaults to (21, 21).
        sigma (float): Gaussian standard deviation (OpenCV ``sigmaX``).
            Defaults to 21.

    Returns:
        torch.Tensor: blurred video of shape [T, C, H, W].
    """
    blurred_frames = []
    for frame in video:
        # Convert to a numpy array of shape (H, W, C) for OpenCV.
        frame_np = frame.permute(1, 2, 0).contiguous().numpy()
        blurred_frame = cv2.GaussianBlur(frame_np, kernel_size, sigma)
        if blurred_frame.ndim == 2:
            # OpenCV drops the channel axis for single-channel images; restore it
            # so the permute back to (C, H, W) works.
            blurred_frame = blurred_frame[:, :, None]
        # Convert back to PyTorch layout (C, H, W).
        blurred_frames.append(torch.from_numpy(blurred_frame).permute(2, 0, 1))
    # Stack the processed frames back into a [T, C, H, W] video.
    return torch.stack(blurred_frames)


def get_transforms_video(name="center", image_size=(256, 256)):
    """Build the video transform pipeline named ``name``.

    Every pipeline starts with ``video_transforms.ToTensorVideo`` and ends by
    normalizing each channel with mean 0.5 and std 0.5.

    Returns:
        The composed transform, or None when ``name`` is None.

    Raises:
        NotImplementedError: for an unknown ``name``.
    """
    if name is None:
        return None
    elif name == "center":
        assert image_size[0] == image_size[1], "image_size must be square for center crop"
        transform_video = transforms.Compose(
            [
                video_transforms.ToTensorVideo(),  # TCHW
                video_transforms.UCFCenterCropVideo(image_size[0]),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    elif name == "resize_crop":
        transform_video = transforms.Compose(
            [
                video_transforms.ToTensorVideo(),  # TCHW
                video_transforms.ResizeCrop(image_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    elif name == "direct_crop":
        transform_video = transforms.Compose(
            [
                video_transforms.ToTensorVideo(),  # TCHW
                video_transforms.RandomCrop(image_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    else:
        raise NotImplementedError(f"Transform {name} not implemented")
    return transform_video


def get_transforms_image(name="center", image_size=(256, 256)):
    """Build the PIL-image transform pipeline named ``name``.

    Each pipeline crops the PIL image, converts it with ``ToTensor`` and
    normalizes each channel with mean 0.5 and std 0.5.

    Returns:
        The composed transform, or None when ``name`` is None.

    Raises:
        NotImplementedError: for an unknown ``name``.
    """
    if name is None:
        return None
    elif name == "center":
        assert image_size[0] == image_size[1], "Image size must be square for center crop"
        transform = transforms.Compose(
            [
                transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    elif name == "resize_crop":
        transform = transforms.Compose(
            [
                transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    else:
        raise NotImplementedError(f"Transform {name} not implemented")
    return transform


def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
    """Load an image and return it as a static clip of shape [C, num_frames, H, W].

    If ``transform`` is None, the pipeline is built with ``get_transforms_image``.
    The transformed image is repeated ``num_frames`` times along the time axis.
    """
    image = pil_loader(path)
    if transform is None:
        transform = get_transforms_image(image_size=image_size, name=transform_name)
    image = transform(image)
    video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)  # T C H W
    video = video.permute(1, 0, 2, 3)  # C T H W
    return video


def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
    """Decode a video file and return the transformed clip as [C, T, H, W].

    Audio and metadata returned by ``torchvision.io.read_video`` are discarded.
    If ``transform`` is None, the pipeline is built with ``get_transforms_video``.
    """
    vframes, _aframes, _info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
    if transform is None:
        transform = get_transforms_video(image_size=image_size, name=transform_name)
    video = transform(vframes)  # T C H W
    video = video.permute(1, 0, 2, 3)  # C T H W
    return video


def read_from_path(path, image_size, transform_name="center"):
    """Read a local path or URL as a [C, T, H, W] clip.

    URLs are first downloaded into the ``cache`` directory. The file extension,
    compared case-insensitively, selects the video or image reader.
    """
    if is_url(path):
        path = download_url(path)
    ext = os.path.splitext(path)[-1].lower()
    if ext in VID_EXTENSIONS:
        return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
    else:
        assert ext in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
        return read_image_from_path(path, image_size=image_size, transform_name=transform_name)


def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1), force_video=False, align_method=None, validation_video=None):
    """Save a [C, T, H, W] sample as a ``.png`` image or an ``.mp4`` video.

    Args:
        x (Tensor): sample of shape [C, T, H, W]. It is not modified in place.
        fps: frame rate of the written video (cast to int).
        save_path: output path without extension; ".png" or ".mp4" is appended.
            Required despite the None default.
        normalize: map ``value_range`` to [0, 1] before saving.
        value_range: (low, high) range of the input values.
        force_video: write a video even when T == 1.
        align_method: if truthy, apply ``adain_color_fix`` against
            ``validation_video`` (video branch only).
        validation_video: reference passed to ``adain_color_fix``.

    Returns:
        str: the full path that was written, including the extension.
    """
    save_dir = os.path.dirname(save_path)
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    assert x.ndim == 4

    if not force_video and x.shape[1] == 1:  # T = 1: save as image
        save_path += ".png"
        x = x.squeeze(1)
        save_image([x], save_path, normalize=normalize, value_range=value_range)
    else:
        save_path += ".mp4"
        if normalize:
            low, high = value_range
            # Out-of-place ops so the caller's tensor is left untouched.
            x = x.clamp(min=low, max=high)
            x = (x - low) / max(high - low, 1e-5)
        if align_method:
            x = adain_color_fix(x, validation_video)
        # [0, 1] float -> uint8, and C T H W -> T H W C as write_video expects.
        x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
        write_video(save_path, x, fps=int(fps), video_codec="h264")
    return save_path


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Halve repeatedly with a box filter while the image is much larger than needed.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)

    # Scale so the shorter side equals image_size, then take the central square.
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])


def resize_crop_to_fill(pil_image, image_size):
    """Resize ``pil_image`` to cover ``image_size``, then center-crop to exactly that size.

    ``get_transforms_image("resize_crop")`` calls this helper, but it was
    neither defined nor imported in this module.
    NOTE(review): ``image_size`` is assumed to be (height, width) — confirm
    against callers.
    """
    width, height = pil_image.size  # PIL reports (W, H)
    target_h, target_w = image_size
    scale = max(target_h / height, target_w / width)
    # Guard against rounding dropping one side below the target.
    new_w = max(target_w, round(width * scale))
    new_h = max(target_h, round(height * scale))
    image = pil_image.resize((new_w, new_h), resample=Image.BICUBIC)
    top = (new_h - target_h) // 2
    left = (new_w - target_w) // 2
    return image.crop((left, top, left + target_w, top + target_h))


class StatefulDistributedSampler(DistributedSampler):
    """DistributedSampler that can resume partway through an epoch.

    ``set_start_index(k)`` makes the next iteration skip the first ``k`` indices
    of this replica's index list.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.start_index: int = 0

    def __iter__(self) -> Iterator:
        indices = list(super().__iter__())
        return iter(indices[self.start_index :])

    def __len__(self) -> int:
        # Clamp at 0 so len() stays valid (and matches the empty iterator) when
        # start_index exceeds this replica's sample count.
        return max(0, self.num_samples - self.start_index)

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index


def prepare_dataloader(
    dataset,
    batch_size,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: Optional[ProcessGroup] = None,
    **kwargs,
):
    r"""
    Prepare a dataloader for distributed training. The dataloader is a
    ``torch.utils.data.DataLoader`` driven by a ``StatefulDistributedSampler``.

    Args:
        dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
        batch_size (int): Number of samples per batch.
        shuffle (bool, optional): Whether the sampler shuffles the dataset. Defaults to False.
        seed (int, optional): Base seed for the ``random``, ``numpy`` and ``torch`` RNGs
            in each dataloader worker; worker ``i`` is seeded with ``seed + i``.
            Defaults to 1024. The sampler's shuffle order does not use this value;
            it uses ``StatefulDistributedSampler``'s default seed.
        drop_last (bool, optional): Set to True to drop the last incomplete batch if the
            dataset size is not divisible by the batch size. If False, the last batch
            will be smaller. Defaults to False.
        pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
        num_workers (int, optional): Number of worker processes for this dataloader. Defaults to 0,
            in which case loading happens in the main process and the worker seeding is not applied.
        process_group (ProcessGroup, optional): Group whose size and rank define the sampler's
            replicas. Defaults to the default process group.
        kwargs (dict): Optional parameters forwarded to ``torch.utils.data.DataLoader``.

    Returns:
        :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
    """
    _kwargs = kwargs.copy()
    process_group = process_group or _get_default_group()
    sampler = StatefulDistributedSampler(
        dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
    )

    # Deterministic dataloader. Each worker gets its own seed; an identical seed in
    # every worker would make all workers draw the same random augmentations.
    # NOTE(review): workers with the same worker_id on different ranks still share
    # a seed, because the rank is not mixed in.
    def seed_worker(worker_id):
        worker_seed = (seed + worker_id) % 2**32
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        worker_init_fn=seed_worker,
        drop_last=drop_last,
        pin_memory=pin_memory,
        num_workers=num_workers,
        **_kwargs,
    )