Commit 0e56f303 authored by mashun

pyramid-flow
import os
import json
import torch
import time
import random
from typing import Iterable
from collections import OrderedDict
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset, IterableDataset, DistributedSampler, RandomSampler
from torch.utils.data.dataloader import default_collate
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms import functional as F
from .bucket_loader import Bucketeer, TemporalLengthBucketeer
class IterLoader:
"""
A wrapper that converts a DataLoader into an infinite iterator.
Modified from:
https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
"""
def __init__(self, dataloader: DataLoader, use_distributed: bool = False, epoch: int = 0):
self._dataloader = dataloader
self.iter_loader = iter(self._dataloader)
self._use_distributed = use_distributed
self._epoch = epoch
@property
def epoch(self) -> int:
return self._epoch
def __next__(self):
try:
data = next(self.iter_loader)
except StopIteration:
self._epoch += 1
if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
self._dataloader.sampler.set_epoch(self._epoch)
time.sleep(2) # Prevent possible deadlock during epoch transition
self.iter_loader = iter(self._dataloader)
data = next(self.iter_loader)
return data
def __iter__(self):
return self
def __len__(self):
return len(self._dataloader)
def identity(x):
return x
def create_image_text_dataloaders(dataset, batch_size, num_workers,
multi_aspect_ratio=True, epoch=0, sizes=[(512, 512), (384, 640), (640, 384)],
use_distributed=True, world_size=None, rank=None,
):
"""
The dataset has already been split across different ranks
"""
if use_distributed:
assert world_size is not None
assert rank is not None
sampler = DistributedSampler(
dataset,
shuffle=True,
num_replicas=world_size,
rank=rank,
seed=epoch,
)
else:
sampler = RandomSampler(dataset)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
sampler=sampler,
collate_fn=identity if multi_aspect_ratio else default_collate,
drop_last=True,
)
if multi_aspect_ratio:
dataloader_iterator = Bucketeer(
dataloader,
sizes=sizes,
is_infinite=True, epoch=epoch,
)
else:
dataloader_iterator = iter(dataloader)
# To make it infinite
loader = IterLoader(dataloader_iterator, use_distributed=False, epoch=epoch)
return loader
def create_length_grouped_video_text_dataloader(dataset, batch_size, num_workers, max_frames,
world_size=None, rank=None, epoch=0, use_distributed=False):
if use_distributed:
assert world_size is not None
assert rank is not None
sampler = DistributedSampler(
dataset,
shuffle=True,
num_replicas=world_size,
rank=rank,
seed=epoch,
)
else:
sampler = RandomSampler(dataset)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
sampler=sampler,
collate_fn=identity,
drop_last=True,
)
# make it infinite
dataloader_iterator = TemporalLengthBucketeer(
dataloader,
max_frames=max_frames,
epoch=epoch,
)
return dataloader_iterator
def create_mixed_dataloaders(
dataset, batch_size, num_workers, world_size=None, rank=None, epoch=0,
image_mix_ratio=0.1, use_image_video_mixed_training=True,
):
"""
The video & image mixed training dataloader builder
"""
assert world_size is not None
assert rank is not None
image_gpus = max(1, int(world_size * image_mix_ratio))
if use_image_video_mixed_training:
video_gpus = world_size - image_gpus
else:
# only use video data
video_gpus = world_size
image_gpus = 0
print(f"{image_gpus} gpus for image, {video_gpus} gpus for video")
if rank < video_gpus:
sampler = DistributedSampler(
dataset,
shuffle=True,
num_replicas=video_gpus,
rank=rank,
seed=epoch,
)
else:
sampler = DistributedSampler(
dataset,
shuffle=True,
num_replicas=image_gpus,
rank=rank - video_gpus,
seed=epoch,
)
loader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
sampler=sampler,
collate_fn=default_collate,
drop_last=True,
)
# To make it infinite
loader = IterLoader(loader, use_distributed=True, epoch=epoch)
return loader
import os
import json
import jsonlines
import torch
import math
import random
import cv2
from tqdm import tqdm
from collections import OrderedDict
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import numpy as np
import subprocess
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms import functional as F
class ImageTextDataset(Dataset):
"""
Usage:
The dataset class for image-text pairs, used for image generation training
It supports multi-aspect ratio training
params:
anno_file: The annotation file list
add_normalize: whether to normalize the input image pixels to [-1, 1], default: True
ratios: The aspect ratios during training, format: width / height
sizes: The resolution of training images, format: (width, height)
"""
def __init__(
self, anno_file, add_normalize=True,
ratios=[1/1, 3/5, 5/3],
sizes=[(1024, 1024), (768, 1280), (1280, 768)],
crop_mode='random', p_random_ratio=0.0,
):
# Ratios and Sizes : (w h)
super().__init__()
self.image_annos = []
if not isinstance(anno_file, list):
anno_file = [anno_file]
for anno_file_ in anno_file:
print(f"Load image annotation files from {anno_file_}")
with jsonlines.open(anno_file_, 'r') as reader:
for item in reader:
self.image_annos.append(item)
print(f"Totally Remained {len(self.image_annos)} images")
transform_list = [
transforms.ToTensor(),
]
if add_normalize:
transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
self.transform = transforms.Compose(transform_list)
print(f"Transform List is {transform_list}")
assert crop_mode in ['center', 'random']
self.crop_mode = crop_mode
self.ratios = ratios
self.sizes = sizes
self.p_random_ratio = p_random_ratio
def get_closest_size(self, x):
if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
best_size_idx = np.random.randint(len(self.ratios))
else:
w, h = x.width, x.height
best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
return self.sizes[best_size_idx]
def get_resize_size(self, orig_size, tgt_size):
if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
resize_size = max(alt_min, min(tgt_size))
else:
alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
resize_size = max(alt_max, max(tgt_size))
return resize_size
def __len__(self):
return len(self.image_annos)
def __getitem__(self, index):
image_anno = self.image_annos[index]
try:
img = Image.open(image_anno['image']).convert("RGB")
text = image_anno['text']
assert isinstance(text, str), "Text should be str"
size = self.get_closest_size(img)
resize_size = self.get_resize_size((img.width, img.height), size)
img = transforms.functional.resize(img, resize_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
if self.crop_mode == 'center':
img = transforms.functional.center_crop(img, (size[1], size[0]))
elif self.crop_mode == 'random':
img = transforms.RandomCrop((size[1], size[0]))(img)
else:
img = transforms.functional.center_crop(img, (size[1], size[0]))
image_tensor = self.transform(img)
return {
"video": image_tensor, # using keyname `video`, to be compatible with video
"text" : text,
"identifier": 'image',
}
except Exception as e:
print(f'Load Image Error with {e}')
return self.__getitem__(random.randint(0, self.__len__() - 1))
class LengthGroupedVideoTextDataset(Dataset):
"""
Usage:
The dataset class for video-text pairs, used for video generation training
It groups videos with the same number of frames together
Now only supporting fixed resolution during training
params:
anno_file: The annotation file list
max_frames: The maximum temporal length (this is the vae latent temporal length), e.g., 16 => (16 - 1) * 8 + 1 = 121 frames
load_vae_latent: Load the pre-extracted vae latents during training; we recommend extracting the latents in advance
to reduce the time cost per batch
load_text_fea: Load the pre-extracted text features during training; we recommend extracting the prompt text features
in advance, since the T5 encoder consumes a lot of GPU memory
"""
def __init__(self, anno_file, max_frames=16, resolution='384p', load_vae_latent=True, load_text_fea=True):
super().__init__()
self.video_annos = []
self.max_frames = max_frames
self.load_vae_latent = load_vae_latent
self.load_text_fea = load_text_fea
self.resolution = resolution
assert load_vae_latent, "Currently only loading pre-extracted vae latents is supported; directly loading video frames will be supported in the future"
if not isinstance(anno_file, list):
anno_file = [anno_file]
for anno_file_ in anno_file:
with jsonlines.open(anno_file_, 'r') as reader:
for item in tqdm(reader):
self.video_annos.append(item)
print(f"Totally Remained {len(self.video_annos)} videos")
def __len__(self):
return len(self.video_annos)
def __getitem__(self, index):
try:
video_anno = self.video_annos[index]
text = video_anno['text']
latent_path = video_anno['latent']
latent = torch.load(latent_path, map_location='cpu') # loading the pre-extracted video latents
# TODO: remove the hard code latent shape checking
if self.resolution == '384p':
assert latent.shape[-1] == 640 // 8
assert latent.shape[-2] == 384 // 8
else:
assert self.resolution == '768p'
assert latent.shape[-1] == 1280 // 8
assert latent.shape[-2] == 768 // 8
cur_temp = latent.shape[2]
cur_temp = min(cur_temp, self.max_frames)
video_latent = latent[:,:,:cur_temp].float()
assert video_latent.shape[1] == 16
if self.load_text_fea:
text_fea_path = video_anno['text_fea']
text_fea = torch.load(text_fea_path, map_location='cpu')
return {
'video': video_latent,
'prompt_embed': text_fea['prompt_embed'],
'prompt_attention_mask': text_fea['prompt_attention_mask'],
'pooled_prompt_embed': text_fea['pooled_prompt_embed'],
"identifier": 'video',
}
else:
return {
'video': video_latent,
'text': text,
"identifier": 'video',
}
except Exception as e:
print(f'Load Video Error with {e}')
return self.__getitem__(random.randint(0, self.__len__() - 1))
class VideoFrameProcessor:
# load a video and transform
def __init__(self, resolution=256, num_frames=24, add_normalize=True, sample_fps=24):
image_size = resolution
transform_list = [
transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
transforms.CenterCrop(image_size),
]
if add_normalize:
transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
print(f"Transform List is {transform_list}")
self.num_frames = num_frames
self.transform = transforms.Compose(transform_list)
self.sample_fps = sample_fps
def __call__(self, video_path):
try:
video_capture = cv2.VideoCapture(video_path)
fps = video_capture.get(cv2.CAP_PROP_FPS)
frames = []
while True:
flag, frame = video_capture.read()
if not flag:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = torch.from_numpy(frame)
frame = frame.permute(2, 0, 1)
frames.append(frame)
video_capture.release()
sample_fps = self.sample_fps
interval = max(int(fps / sample_fps), 1)
frames = frames[::interval]
if len(frames) < self.num_frames:
num_frame_to_pack = self.num_frames - len(frames)
recurrent_num = num_frame_to_pack // len(frames)
frames = frames + recurrent_num * frames + frames[:(num_frame_to_pack % len(frames))]
assert len(frames) >= self.num_frames, f'{len(frames)}'
start_indexs = list(range(0, max(0, len(frames) - self.num_frames + 1)))
start_index = random.choice(start_indexs)
filtered_frames = frames[start_index : start_index+self.num_frames]
assert len(filtered_frames) == self.num_frames, f"The number of sampled frames should equal {self.num_frames}"
filtered_frames = torch.stack(filtered_frames).float() / 255
filtered_frames = self.transform(filtered_frames)
filtered_frames = filtered_frames.permute(1, 0, 2, 3)
return filtered_frames, None
except Exception as e:
print(f"Load video: {video_path} Error, Exception {e}")
return None, None
class VideoDataset(Dataset):
def __init__(self, anno_file, resolution=256, max_frames=6, add_normalize=True):
super().__init__()
self.video_annos = []
self.max_frames = max_frames
if not isinstance(anno_file, list):
anno_file = [anno_file]
print(f"The training video clip frame number is {max_frames} ")
for anno_file_ in anno_file:
print(f"Load annotation file from {anno_file_}")
with jsonlines.open(anno_file_, 'r') as reader:
for item in tqdm(reader):
self.video_annos.append(item)
print(f"Totally Remained {len(self.video_annos)} videos")
self.video_processor = VideoFrameProcessor(resolution, max_frames, add_normalize)
def __len__(self):
return len(self.video_annos)
def __getitem__(self, index):
video_anno = self.video_annos[index]
video_path = video_anno['video']
try:
video_tensors, video_frames = self.video_processor(video_path)
assert video_tensors.shape[1] == self.max_frames
return {
"video": video_tensors,
"identifier": 'video',
}
except Exception as e:
print(f'Load Video Error with {e}')
return self.__getitem__(random.randint(0, self.__len__() - 1))
class ImageDataset(Dataset):
def __init__(self, anno_file, resolution=256, max_frames=8, add_normalize=True):
super().__init__()
self.image_annos = []
self.max_frames = max_frames
image_paths = []
if not isinstance(anno_file, list):
anno_file = [anno_file]
for anno_file_ in anno_file:
print(f"Load annotation file from {anno_file_}")
with jsonlines.open(anno_file_, 'r') as reader:
for item in tqdm(reader):
image_paths.append(item['image'])
print(f"Totally Remained {len(image_paths)} images")
# pack multiple frames
for idx in range(0, len(image_paths), self.max_frames):
image_path_shard = image_paths[idx : idx + self.max_frames]
if len(image_path_shard) < self.max_frames:
image_path_shard = image_path_shard + image_paths[:self.max_frames - len(image_path_shard)]
assert len(image_path_shard) == self.max_frames
self.image_annos.append(image_path_shard)
image_size = resolution
transform_list = [
transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
]
if add_normalize:
transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
print(f"Transform List is {transform_list}")
self.transform = transforms.Compose(transform_list)
def __len__(self):
return len(self.image_annos)
def __getitem__(self, index):
image_paths = self.image_annos[index]
try:
packed_pil_frames = [Image.open(image_path).convert("RGB") for image_path in image_paths]
filtered_frames = [self.transform(frame) for frame in packed_pil_frames]
filtered_frames = torch.stack(filtered_frames) # [t, c, h, w]
filtered_frames = filtered_frames.permute(1, 0, 2, 3) # [c, t, h, w]
return {
"video": filtered_frames,
"identifier": 'image',
}
except Exception as e:
print(f'Load Images Error with {e}')
return self.__getitem__(random.randint(0, self.__len__() - 1))
from .scheduling_cosine_ddpm import DDPMCosineScheduler
from .scheduling_flow_matching import PyramidFlowMatchEulerDiscreteScheduler
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin
@dataclass
class DDPMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.Tensor
class DDPMCosineScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
scaler: float = 1.0,
s: float = 0.008,
):
self.scaler = scaler
self.s = torch.tensor([s])
self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
def _alpha_cumprod(self, t, device):
if self.scaler > 1:
t = 1 - (1 - t) ** self.scaler
elif self.scaler < 1:
t = t**self.scaler
alpha_cumprod = torch.cos(
(t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5
) ** 2 / self._init_alpha_cumprod.to(device)
return alpha_cumprod.clamp(0.0001, 0.9999)
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.Tensor`: scaled input sample
"""
return sample
def set_timesteps(
self,
num_inference_steps: int = None,
timesteps: Optional[List[int]] = None,
device: Union[str, torch.device] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
`timesteps` must be `None`.
device (`str` or `torch.device`, optional):
the device to which the timesteps are moved.
"""
if timesteps is None:
timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device)
if not isinstance(timesteps, torch.Tensor):
timesteps = torch.Tensor(timesteps).to(device)
self.timesteps = timesteps
def step(
self,
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
dtype = model_output.dtype
device = model_output.device
t = timestep
prev_t = self.previous_timestep(t)
alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]])
alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
alpha = alpha_cumprod / alpha_cumprod_prev
mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt())
std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise
pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
if not return_dict:
return (pred.to(dtype),)
return DDPMSchedulerOutput(prev_sample=pred.to(dtype))
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
device = original_samples.device
dtype = original_samples.dtype
alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
)
noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
return noisy_samples.to(dtype=dtype)
def __len__(self):
return self.config.num_train_timesteps
def previous_timestep(self, timestep):
index = (self.timesteps - timestep[0]).abs().argmin().item()
prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0])
return prev_t
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List
import math
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class PyramidFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0, # Following Stable diffusion 3,
stages: int = 3,
stage_range: List = [0, 1/3, 2/3, 1],
gamma: float = 1/3,
):
self.timestep_ratios = {} # The timestep ratio for each stage
self.timesteps_per_stage = {} # The detailed timesteps per stage
self.sigmas_per_stage = {}
self.start_sigmas = {}
self.end_sigmas = {}
self.ori_start_sigmas = {}
# self.init_sigmas()
self.init_sigmas_for_each_stage()
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
self.gamma = gamma
def init_sigmas(self):
"""
initialize the global timesteps and sigmas
"""
num_train_timesteps = self.config.num_train_timesteps
shift = self.config.shift
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.timesteps = sigmas * num_train_timesteps
self._step_index = None
self._begin_index = None
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
def init_sigmas_for_each_stage(self):
"""
Init the timesteps for each stage
"""
self.init_sigmas()
stage_distance = []
stages = self.config.stages
training_steps = self.config.num_train_timesteps
stage_range = self.config.stage_range
# Init the start and end point of each stage
for i_s in range(stages):
# To decide the start and ends point
start_indice = int(stage_range[i_s] * training_steps)
start_indice = max(start_indice, 0)
end_indice = int(stage_range[i_s+1] * training_steps)
end_indice = min(end_indice, training_steps)
start_sigma = self.sigmas[start_indice].item()
end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
self.ori_start_sigmas[i_s] = start_sigma
if i_s != 0:
ori_sigma = 1 - start_sigma
gamma = self.config.gamma
corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
# corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
start_sigma = 1 - corrected_sigma
stage_distance.append(start_sigma - end_sigma)
self.start_sigmas[i_s] = start_sigma
self.end_sigmas[i_s] = end_sigma
# Determine the ratio of each stage according to flow length
tot_distance = sum(stage_distance)
for i_s in range(stages):
if i_s == 0:
start_ratio = 0.0
else:
start_ratio = sum(stage_distance[:i_s]) / tot_distance
if i_s == stages - 1:
end_ratio = 1.0
else:
end_ratio = sum(stage_distance[:i_s+1]) / tot_distance
self.timestep_ratios[i_s] = (start_ratio, end_ratio)
# Determine the timesteps and sigmas for each stage
for i_s in range(stages):
timestep_ratio = self.timestep_ratios[i_s]
timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
timesteps = np.linspace(
timestep_max, timestep_min, training_steps + 1,
)
self.timesteps_per_stage[i_s] = timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
stage_sigmas = np.linspace(
1, 0, training_steps + 1,
)
self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None):
"""
Setting the timesteps and sigmas for each stage
"""
self.num_inference_steps = num_inference_steps
training_steps = self.config.num_train_timesteps
self.init_sigmas()
stage_timesteps = self.timesteps_per_stage[stage_index]
timestep_max = stage_timesteps[0].item()
timestep_min = stage_timesteps[-1].item()
timesteps = np.linspace(
timestep_max, timestep_min, num_inference_steps,
)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
stage_sigmas = self.sigmas_per_stage[stage_index]
sigma_max = stage_sigmas[0].item()
sigma_min = stage_sigmas[-1].item()
ratios = np.linspace(
sigma_max, sigma_min, num_inference_steps
)
sigmas = torch.from_numpy(ratios).to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
tuple.
Returns:
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._step_index = 0
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
prev_sample = sample + (sigma_next - sigma) * model_output
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
# Pyramid Flow's DiT Finetuning Guide
This is the finetuning guide for the DiT in Pyramid Flow. We provide instructions for both the autoregressive and non-autoregressive versions. The former is more research-oriented and the latter is more stable (but less efficient without the temporal pyramid). Please refer to [another document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/VAE) for VAE finetuning.
## Hardware Requirements
+ DiT finetuning: At least 8 A100 GPUs.
## Prepare the Dataset
The training dataset should be arranged into a json file with `video` and `text` fields. Since video vae latent extraction is very slow, we strongly recommend pre-extracting the video vae latents to save training time. We provide a video vae latent extraction script in the `tools` folder. You can run it with the following command:
```bash
sh scripts/extract_vae_latent.sh
```
(Optional) Since the T5 text encoder consumes a lot of GPU memory, pre-extracting the text features will save training memory. We also provide a text feature extraction script in the `tools` folder. You can run it with the following command:
```bash
sh scripts/extract_text_feature.sh
```
The final training annotation json file should follow the format below:
```
{"video": video_path, "text": text prompt, "latent": extracted video vae latent, "text_fea": extracted text feature}
```
We provide example json annotation files for [video](https://github.com/jy0205/Pyramid-Flow/blob/main/annotation/video_text.jsonl) and [image](https://github.com/jy0205/Pyramid-Flow/blob/main/annotation/image_text.jsonl) training in the `annotation` folder. You can refer to them when preparing your training dataset.
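For reference, a minimal sketch of writing such an annotation file; the paths below are placeholders for your own videos, pre-extracted latents, and text features:
```python
import json

# Placeholder entries; replace with your own paths produced by the extraction scripts above.
annotations = [
    {
        "video": "/data/videos/clip_0001.mp4",
        "text": "a cat playing with a ball of yarn",
        "latent": "/data/latents/clip_0001.pt",
        "text_fea": "/data/text_feas/clip_0001-text.pt",
    },
]

# Write one json object per line (jsonl), matching the provided example annotation files.
with open("video_text.jsonl", "w") as f:
    for item in annotations:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")
```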
## Run Training
We provide two types of training scripts: (1) autoregressive video generation training with the temporal pyramid; (2) full-sequence diffusion training with pyramid-flow for both text-to-image and text-to-video. These correspond to the following two scripts; run them with at least 8 GPUs:
+ `scripts/train_pyramid_flow.sh`: The autoregressive video generation training with temporal pyramid.
```bash
sh scripts/train_pyramid_flow.sh
```
+ `scripts/train_pyramid_flow_without_ar.sh`: Using pyramid-flow for full-sequence diffusion training.
```bash
sh scripts/train_pyramid_flow_without_ar.sh
```
## Tips
+ For the 768p version, make sure to add the arg `--gradient_checkpointing`
+ The param `NUM_FRAMES` should be set to a multiple of 8
+ The param `video_sync_group` indicates the number of processes that receive the same input video, used for temporal pyramid AR training. We recommend setting this value to 4, 8 or 16 (16 is better if you have more GPUs)
+ Make sure to set `NUM_FRAMES % VIDEO_SYNC_GROUP == 0`, `GPUS % VIDEO_SYNC_GROUP == 0`, and `BATCH_SIZE % 4 == 0` (a minimal check is sketched after this list)
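A minimal sketch of verifying these constraints before launching a run; the variable names mirror the script parameters and the values below are placeholders:
```python
# Placeholder values; set these to match your training script.
GPUS = 8
NUM_FRAMES = 16
VIDEO_SYNC_GROUP = 8
BATCH_SIZE = 4

assert NUM_FRAMES % 8 == 0, "NUM_FRAMES should be a multiple of 8"
assert NUM_FRAMES % VIDEO_SYNC_GROUP == 0, "NUM_FRAMES must be divisible by VIDEO_SYNC_GROUP"
assert GPUS % VIDEO_SYNC_GROUP == 0, "GPUS must be divisible by VIDEO_SYNC_GROUP"
assert BATCH_SIZE % 4 == 0, "BATCH_SIZE must be a multiple of 4"
```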
# Pyramid Flow's VAE Training Guide
This is the training guide for a [MAGVIT-v2](https://arxiv.org/abs/2310.05737)-like continuous 3D VAE, which should be quite flexible. Feel free to build your own video generative model on top of this VAE training code. Please refer to [another document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/DiT) for DiT finetuning.
## Hardware Requirements
+ VAE training: At least 8 A100 GPUs.
## Prepare the Dataset
The training of our causal video vae uses both image and video data. Both should be arranged into a json file with a `video` or `image` field. The final training annotation json file should follow the format below (a short generation sketch follows the example):
```
# For Video
{"video": video_path}
# For Image
{"image": image_path}
```
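A minimal sketch of generating this annotation file by scanning your data directories (the roots and file patterns below are placeholders); the `tools` scripts included in this commit provide a fuller version:
```python
import json
from pathlib import Path

# Placeholder data roots; point them at your own videos and images.
video_root = Path("/data/videos")
image_root = Path("/data/images")

# Write one json object per line (jsonl), mixing video and image entries for stage-1.
with open("vae_train.jsonl", "w") as f:
    for p in sorted(video_root.glob("*.mp4")):
        f.write(json.dumps({"video": str(p.resolve())}) + "\n")
    for p in sorted(image_root.glob("*.jpg")):
        f.write(json.dumps({"image": str(p.resolve())}) + "\n")
```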
## Run Training
The causal video vae is trained in two stages:
+ Stage-1: image and video mixed training
+ Stage-2: pure video training, using context parallel to load video with more video frames
The VAE training script is `scripts/train_causal_video_vae.sh`, run it as follows:
```bash
sh scripts/train_causal_video_vae.sh
```
We also provide a VAE demo `causal_video_vae_demo.ipynb` for image and video reconstruction.
## Tips
+ For stage-1, we use mixed image and video training. Add the param `--use_image_video_mixed_training` to enable the mixed training. We set the image ratio to 0.1 by default.
+ Setting the `resolution` to 256 is enough for VAE training.
+ For stage-1, `max_frames` is set to 17, which means we use 17 sampled video frames for training.
+ For stage-2, we enable the param `use_context_parallel` to distribute long videos across multiple GPUs. Make sure to set `GPUS % CONTEXT_SIZE == 0` and `NUM_FRAMES=17 * CONTEXT_SIZE + 1` (a quick check is sketched below)
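For example, a quick sanity check of the stage-2 settings (the values below are illustrative):
```python
# Illustrative stage-2 context-parallel settings.
GPUS = 8
CONTEXT_SIZE = 4

assert GPUS % CONTEXT_SIZE == 0, "GPUS must be divisible by CONTEXT_SIZE"
NUM_FRAMES = 17 * CONTEXT_SIZE + 1  # 69 frames when CONTEXT_SIZE == 4
print(f"Training with {NUM_FRAMES} frames across a context-parallel group of size {CONTEXT_SIZE}")
```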
import os
import torch
def extract_vae_weights(vae_checkpoint_path: str,
save_path: str):
checkpoint = torch.load(vae_checkpoint_path)
weights = checkpoint['model']
new_weights = {}
for name, params in weights.items():
if "dis" in name or "prece" in name or "logvar" in name: continue
name = name.split(".", 1)[1]
new_weights[name] = params
torch.save(new_weights, save_path)
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--vae_checkpoint_path", type=str)
parser.add_argument("--save_path", type=str)
args = parser.parse_args()
extract_vae_weights(args.vae_checkpoint_path,
args.save_path)
import os
import json
from pathlib import Path
def get_mapping(path_list):
return {p.stem: str(p.resolve()) for p in path_list}
def generate_image_text(image_root: str,
prompt_root: str,
save_root: str):
image_root = Path(image_root)
prompt_root = Path(prompt_root)
image_path_list = [*image_root.glob("*.jpg"), *image_root.glob("*.png"), *image_root.glob("*.JPEG")]
prompt_path_list = [*prompt_root.glob("*.json")]
image_path_mapping = get_mapping(image_path_list)
prompt_path_mapping = get_mapping(prompt_path_list)
keys = set(image_path_mapping.keys()) & set(prompt_path_mapping.keys())
for key in keys:
with open(prompt_path_mapping[key], "r") as f:
text = json.loads(f.read().strip())['prompt']
tmp = {"image": image_path_mapping[key], "text": text}
with open(os.path.join(save_root, "image_text.jsonl"), "a") as f:
f.write(json.dumps(tmp, ensure_ascii=False) + '\n')
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--image_root", type=str)
parser.add_argument("--prompt_root", type=str)
parser.add_argument("--save_root", type=str)
args = parser.parse_args()
generate_image_text(args.image_root, args.prompt_root, args.save_root)
from pathlib import Path
from argparse import ArgumentParser
import json
from typing import Union
def generate_vae_annotation(data_root: str,
data_type: str,
save_path: str):
assert data_type in ['image', 'video']
data_root = Path(data_root)
if data_type == "video":
data_path_list = [*data_root.glob("*.mp4")]
elif data_type == "image":
data_path_list = [*data_root.glob("*.png"), *data_root.glob("*.jpeg"),
*data_root.glob("*.jpg"), *data_root.glob("*.JPEG")]
else:
raise NotImplementedError
with open(save_path, "w") as f:
for data_path in data_path_list:
f.write(json.dumps({data_type: str(data_path.resolve())}, ensure_ascii=False) + '\n')
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--data_root", type=str)
parser.add_argument("--data_type", type=str)
parser.add_argument("--save_path", type=str)
args = parser.parse_args()
generate_vae_annotation(args.data_root,
args.data_type,
args.save_path)
# Extract the texts corresponding to the given VidGen_1M videos, and set the save locations of the corresponding features
import os
import json
from pathlib import Path
from argparse import ArgumentParser
def get_video_text(video_root,
caption_json_path,
save_root,
video_latent_root,
text_fea_root):
video_root = Path(video_root)
# Support at most 3 directory levels
video_path_list = [*video_root.glob("*.mp4"), *video_root.glob("*/*.mp4"), *video_root.glob("*/*/*.mp4")]
vid_path = {p.stem: str(p.resolve()) for p in video_path_list}
with open(caption_json_path, "r") as f:
captions = json.load(f)
vid_caption = {}
for d in captions:
vid_caption[d['vid']] = d['caption']
os.makedirs(video_latent_root, exist_ok=True)
os.makedirs(text_fea_root, exist_ok=True)
with open(os.path.join(save_root, "video_text.jsonl"), "w") as f:
for vid, vpath in vid_path.items():
text = vid_caption[vid]
latent_path = str(Path(os.path.join(video_latent_root, f"{vid}.pt")).resolve())
text_fea_path = str(Path(os.path.join(text_fea_root, f"{vid}-text.pt")).resolve())
f.write(json.dumps({"video": vpath, "text": text, "latent": latent_path, "text_fea": text_fea_path}, ensure_ascii=False) + '\n')
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--video_root", type=str)
parser.add_argument("-cjp", "--caption_json_path", type=str)
parser.add_argument("-sr", "--save_root", type=str)
parser.add_argument("-vlr", "--video_latent_root", type=str)
parser.add_argument("-tfr", "--text_fea_root", type=str)
args = parser.parse_args()
get_video_text(args.video_root,
args.caption_json_path,
args.save_root,
args.video_latent_root,
args.text_fea_root)
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import torch\n",
"import numpy as np\n",
"import PIL\n",
"from PIL import Image\n",
"from IPython.display import HTML\n",
"from pyramid_dit import PyramidDiTForVideoGeneration\n",
"from IPython.display import Image as ipython_image\n",
"from diffusers.utils import load_image, export_to_video, export_to_gif"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"variant='diffusion_transformer_image' # For low resolution\n",
"model_name = \"pyramid_flux\"\n",
"\n",
"model_path = \"/home/jinyang06/models/pyramid-flow-miniflux\" # The downloaded checkpoint dir\n",
"model_dtype = 'bf16'\n",
"\n",
"device_id = 0\n",
"torch.cuda.set_device(device_id)\n",
"\n",
"model = PyramidDiTForVideoGeneration(\n",
" model_path,\n",
" model_dtype,\n",
" model_name=model_name,\n",
" model_variant=variant,\n",
")\n",
"\n",
"model.vae.to(\"cuda\")\n",
"model.dit.to(\"cuda\")\n",
"model.text_encoder.to(\"cuda\")\n",
"\n",
"model.vae.enable_tiling()\n",
"\n",
"if model_dtype == \"bf16\":\n",
" torch_dtype = torch.bfloat16 \n",
"elif model_dtype == \"fp16\":\n",
" torch_dtype = torch.float16\n",
"else:\n",
" torch_dtype = torch.float32"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Text-to-Image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"shoulder and full head portrait of a beautiful 19 year old girl, brunette, smiling, stunning, highly detailed, glamour lighting, HDR, photorealistic, hyperrealism, octane render, unreal engine\"\n",
"\n",
"# now support 3 aspect ratios\n",
"resolution_dict = {\n",
" '1:1' : (1024, 1024),\n",
" '5:3' : (1280, 768),\n",
" '3:5' : (768, 1280),\n",
"}\n",
"\n",
"ratio = '1:1' # 1:1, 5:3, 3:5\n",
"\n",
"width, height = resolution_dict[ratio]\n",
"\n",
"\n",
"with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
" images = model.generate(\n",
" prompt=prompt,\n",
" num_inference_steps=[20, 20, 20],\n",
" height=height,\n",
" width=width,\n",
" temp=1,\n",
" guidance_scale=9.0, \n",
" output_type=\"pil\",\n",
" save_memory=False, \n",
" )\n",
"\n",
"display(images[0])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
import os
import torch
import sys
import argparse
import random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from diffusers.utils import export_to_video
from pyramid_dit import PyramidDiTForVideoGeneration
from trainer_misc import init_distributed_mode, init_sequence_parallel_group
import PIL
from PIL import Image
def get_args():
parser = argparse.ArgumentParser('Pytorch Multi-process Script', add_help=False)
parser.add_argument('--model_name', default='pyramid_mmdit', type=str, help="The model name", choices=["pyramid_flux", "pyramid_mmdit"])
parser.add_argument('--model_dtype', default='bf16', type=str, help="The Model Dtype: bf16")
parser.add_argument('--model_path', default='/home/jinyang06/models/pyramid-flow', type=str, help='Set it to the downloaded checkpoint dir')
parser.add_argument('--variant', default='diffusion_transformer_768p', type=str,)
parser.add_argument('--task', default='t2v', type=str, choices=['i2v', 't2v'])
parser.add_argument('--temp', default=16, type=int, help='The generated latent num, num_frames = temp * 8 + 1')
parser.add_argument('--sp_group_size', default=2, type=int, help="The number of gpus used for inference, should be 2 or 4")
parser.add_argument('--sp_proc_num', default=-1, type=int, help="The number of processes used for video training; default=-1 means using all processes.")
return parser.parse_args()
def main():
args = get_args()
# setup DDP
init_distributed_mode(args)
assert args.world_size == args.sp_group_size, "The sequence parallel size should equal the DDP world size"
# Enable sequence parallel
init_sequence_parallel_group(args)
device = torch.device('cuda')
rank = args.rank
model_dtype = args.model_dtype
if args.model_name == "pyramid_flux":
assert args.variant != "diffusion_transformer_768p", "pyramid_flux does not support high resolution yet; \
it will be released after training finishes. You can set model_name to pyramid_mmdit for 768p generation"
model = PyramidDiTForVideoGeneration(
args.model_path,
model_dtype,
model_name=args.model_name,
model_variant=args.variant,
)
model.vae.to(device)
model.dit.to(device)
model.text_encoder.to(device)
model.vae.enable_tiling()
if model_dtype == "bf16":
torch_dtype = torch.bfloat16
elif model_dtype == "fp16":
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
# The video generation config
if args.variant == 'diffusion_transformer_768p':
width = 1280
height = 768
else:
assert args.variant == 'diffusion_transformer_384p'
width = 640
height = 384
if args.task == 't2v':
# prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"
prompt = "a cat on the moon, salt desert, cinematic style, shot on 35mm film, vivid colors"
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):
frames = model.generate(
prompt=prompt,
num_inference_steps=[20, 20, 20],
video_num_inference_steps=[10, 10, 10],
height=height,
width=width,
temp=args.temp,
guidance_scale=7.0, # The guidance for the first frame, set it to 7 for 384p variant
video_guidance_scale=5.0, # The guidance for the other video latent
output_type="pil",
save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
cpu_offloading=False, # If OOM, set it to True to reduce memory usage
inference_multigpu=True,
)
if rank == 0:
export_to_video(frames, "./text_to_video_sample.mp4", fps=24)
else:
assert args.task == 'i2v'
image_path = 'assets/the_great_wall.jpg'
image = Image.open(image_path).convert("RGB")
image = image.resize((width, height))
prompt = "FPV flying over the Great Wall"
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):
frames = model.generate_i2v(
prompt=prompt,
input_image=image,
num_inference_steps=[10, 10, 10],
temp=args.temp,
video_guidance_scale=4.0,
output_type="pil",
save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
cpu_offloading=False, # If OOM, set it to True to reduce memory usage
inference_multigpu=True,
)
if rank == 0:
export_to_video(frames, "./image_to_video_sample.mp4", fps=24)
torch.distributed.barrier()
if __name__ == "__main__":
main()
# Unique model identifier
modelCode=1126
# Model name
modelName=pyramid-flow_pytorch
# Model description
modelDescription=Fast, high-quality video generation
# Application scenarios
appScenario=training,inference,aigc,e-commerce,education,broadcast media
# Framework type
frameType=Pytorch
from .pyramid_dit_for_video_gen_pipeline import PyramidDiTForVideoGeneration
from .flux_modules import FluxSingleTransformerBlock, FluxTransformerBlock, FluxTextEncoderWithMask
from .mmdit_modules import JointTransformerBlock, SD3TextEncoderWithMask
from .modeling_pyramid_flux import PyramidFluxTransformer
from .modeling_text_encoder import FluxTextEncoderWithMask
from .modeling_flux_block import FluxSingleTransformerBlock, FluxTransformerBlock
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.models.activations import get_activation, FP32SiLU
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
Args
timesteps (torch.Tensor):
a 1-D Tensor of N indices, one per batch element. These may be fractional.
embedding_dim (int):
the dimension of the output.
flip_sin_to_cos (bool):
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
downscale_freq_shift (float):
Controls the delta between frequencies between dimensions
scale (float):
Scaling factor applied to the embeddings.
max_period (int):
Controls the maximum frequency of the embeddings
Returns
torch.Tensor: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
return t_emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = get_activation(act_fn)
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
if post_act_fn is None:
self.post_act = None
else:
self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
elif act_fn == "silu_fp32":
self.act_1 = FP32SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(self, timestep, guidance, pooled_projection):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
time_guidance_emb = timesteps_emb + guidance_emb
pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections
return conditioning
class CombinedTimestepTextProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(self, timestep, pooled_projection):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
return conditioning