Unverified Commit 66eef9a6 authored by glide-the, committed by GitHub

fix: CogVideoX train dataset _preprocess_data crop video (#9574)



* Removed the int8-to-float32 conversion (`* 2.0 - 1.0`) from `train_transforms`, as it caused image overexposure.

Added a `_resize_for_rectangle_crop` function to enable video cropping. The cropping mode can be configured via `video_reshape_mode`, which supports the options ['center', 'random', 'none'].

* The constant 127.5 may incur precision loss in division operations; normalization now uses `(frames - 127.5) / 127.5` (illustrated in the preprocessing sketch after the `_preprocess_data` hunk below).

* wandb requires PIL images, so validation frames are converted before logging.

* Fixed a resizing bug.

* Deleted the Jupyter notebook.

* make style

* Update examples/cogvideo/README.md

* make style

---------

Co-authored-by: --unset <--unset>
Co-authored-by: Aryan <aryan@huggingface.co>
parent 63a5c874
@@ -180,6 +180,7 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen
 > [!TIP]
 > You can pass `--use_8bit_adam` to reduce the memory requirements of training.
+> You can pass `--video_reshape_mode` to enable video cropping functionality, supporting options: ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples.
 > [!IMPORTANT]
 > The following settings have been tested at the time of adding CogVideoX LoRA training support:
...
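As a rough illustration of how the three modes differ: after the resize step, only the crop offset changes. The helper below is hypothetical (not part of the commit); note that in the committed code, 'none' currently takes the same random-offset branch as 'random'.

```python
import numpy as np

def crop_offsets(h, w, target_h, target_w, mode="center"):
    # The frame has already been resized so one side matches the target
    # and the other is >= the target; choose where to crop from.
    delta_h, delta_w = h - target_h, w - target_w
    if mode == "center":
        return delta_h // 2, delta_w // 2
    if mode in ("random", "none"):  # mirrors the committed branch
        return np.random.randint(0, delta_h + 1), np.random.randint(0, delta_w + 1)
    raise NotImplementedError(mode)
```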
@@ -21,7 +21,9 @@ import shutil
 from pathlib import Path
 from typing import List, Optional, Tuple, Union

+import numpy as np
 import torch
+import torchvision.transforms as TT
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -29,12 +31,14 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import resize
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer

 import diffusers
 from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
@@ -214,6 +218,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -413,6 +423,7 @@ class VideoDataset(Dataset):
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ class VideoDataset(Dataset):
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ class VideoDataset(Dataset):
         return instance_prompts, instance_videos

+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
@@ -542,15 +586,14 @@ class VideoDataset(Dataset):
         decord.bridge.set_bridge("torch")

-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Loading progress resize and crop videos",
         )
+        videos = []

         for filename in self.instance_video_paths:
-            video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
+            video_reader = decord.VideoReader(uri=filename.as_posix())
             video_num_frames = len(video_reader)

             start_frame = min(self.skip_frames_start, video_num_frames)
@@ -576,10 +619,16 @@ class VideoDataset(Dataset):
             assert (selected_num_frames - 1) % 4 == 0

             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            frames = (frames - 127.5) / 127.5
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            progress_dataset_bar.set_description(
+                f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
+            )
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+            progress_dataset_bar.update(1)

+        progress_dataset_bar.close()
         return videos
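For reference, a minimal standalone sketch of the new preprocessing path (random frames stand in for decord output; the target size and center mode are illustrative assumptions):

```python
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import crop, resize

# Stand-in for decord output: 8 RGB frames, 360x640, uint8 in [0, 255].
frames = torch.randint(0, 256, (8, 360, 640, 3), dtype=torch.uint8)

# Normalize to [-1, 1] (uint8 minus float promotes to float32), channels first.
frames = (frames - 127.5) / 127.5
frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]

# Resize so the limiting side matches the target, then center-crop to 480x720.
target_h, target_w = 480, 720
if frames.shape[3] / frames.shape[2] > target_w / target_h:
    new_size = [target_h, int(frames.shape[3] * target_h / frames.shape[2])]
else:
    new_size = [int(frames.shape[2] * target_w / frames.shape[3]), target_w]
frames = resize(frames, size=new_size, interpolation=InterpolationMode.BICUBIC)

top = (frames.shape[2] - target_h) // 2
left = (frames.shape[3] - target_w) // 2
frames = crop(frames, top=top, left=left, height=target_h, width=target_w)
print(frames.shape)  # torch.Size([8, 3, 480, 720])
```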
@@ -694,8 +743,13 @@ def log_validation(

     videos = []
     for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
-        videos.append(video)
+        pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
+        pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
+
+        image_np = VaeImageProcessor.pt_to_numpy(pt_images)
+        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
+
+        videos.append(image_pil)

     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
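Since wandb logs PIL images, validation frames are routed through `VaeImageProcessor`. A minimal sketch of that conversion, with a random tensor standing in for the pipeline's `output_type="pt"` frames:

```python
import torch
from diffusers.image_processor import VaeImageProcessor

# Stand-in for pipe(..., output_type="pt").frames[0]: [F, C, H, W] in [0, 1].
pt_frames = torch.rand(49, 3, 480, 720)

np_frames = VaeImageProcessor.pt_to_numpy(pt_frames)    # ndarray [F, H, W, C] in [0, 1]
pil_frames = VaeImageProcessor.numpy_to_pil(np_frames)  # list of PIL.Image, one per frame
```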
@@ -1171,6 +1225,7 @@ def main(args):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1234,21 @@ def main(args):
         id_token=args.id_token,
     )

-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         latent_dist = vae.encode(video).latent_dist
         return latent_dist

-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Loading Encode videos",
+    )
+    train_dataset.instance_videos = [
+        encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
+    ]
+    progress_encode_bar.close()

     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
...
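Note that the dataset stores each video's latent distribution rather than sampled latents; `collate_fn` samples and scales per batch. A rough sketch of that flow, assuming the THUDM/CogVideoX-2b VAE checkpoint (any `AutoencoderKLCogVideoX` behaves the same):

```python
import torch
from diffusers import AutoencoderKLCogVideoX

# Assumption: the CogVideoX-2b checkpoint; swap in whichever model you train against.
vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae")

# One preprocessed video: [F, C, H, W] in [-1, 1], as produced by the dataset.
# 13 frames keeps the example light and satisfies the script's (F - 1) % 4 == 0 assert.
frames = torch.rand(13, 3, 480, 720) * 2 - 1
video = frames.unsqueeze(0).permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]

with torch.no_grad():
    latent_dist = vae.encode(video.to(vae.dtype)).latent_dist  # stored by the dataset

# Done later, per batch, in collate_fn:
latents = latent_dist.sample() * vae.config.scaling_factor
```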