Commit 0e56f303 authored by mashun

pyramid-flow
import os
import json
import torch
import time
import random
from typing import Iterable
from collections import OrderedDict
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset, IterableDataset, DistributedSampler, RandomSampler
from torch.utils.data.dataloader import default_collate
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms import functional as F
from .bucket_loader import Bucketeer, TemporalLengthBucketeer
class IterLoader:
"""
A wrapper to convert DataLoader as an infinite iterator.
Modified from:
https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
"""
def __init__(self, dataloader: DataLoader, use_distributed: bool = False, epoch: int = 0):
self._dataloader = dataloader
self.iter_loader = iter(self._dataloader)
self._use_distributed = use_distributed
self._epoch = epoch
@property
def epoch(self) -> int:
return self._epoch
def __next__(self):
try:
data = next(self.iter_loader)
except StopIteration:
self._epoch += 1
if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
self._dataloader.sampler.set_epoch(self._epoch)
time.sleep(2) # Prevent possible deadlock during epoch transition
self.iter_loader = iter(self._dataloader)
data = next(self.iter_loader)
return data
def __iter__(self):
return self
def __len__(self):
return len(self._dataloader)
def identity(x):
return x
def create_image_text_dataloaders(dataset, batch_size, num_workers,
multi_aspect_ratio=True, epoch=0, sizes=[(512, 512), (384, 640), (640, 384)],
use_distributed=True, world_size=None, rank=None,
):
"""
    The dataset has already been split across the different ranks
"""
if use_distributed:
assert world_size is not None
assert rank is not None
sampler = DistributedSampler(
dataset,
shuffle=True,
num_replicas=world_size,
rank=rank,
seed=epoch,
)
else:
sampler = RandomSampler(dataset)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
sampler=sampler,
collate_fn=identity if multi_aspect_ratio else default_collate,
drop_last=True,
)
if multi_aspect_ratio:
dataloader_iterator = Bucketeer(
dataloader,
sizes=sizes,
is_infinite=True, epoch=epoch,
)
else:
dataloader_iterator = iter(dataloader)
# To make it infinite
loader = IterLoader(dataloader_iterator, use_distributed=False, epoch=epoch)
return loader
def create_length_grouped_video_text_dataloader(dataset, batch_size, num_workers, max_frames,
world_size=None, rank=None, epoch=0, use_distributed=False):
if use_distributed:
assert world_size is not None
assert rank is not None
sampler = DistributedSampler(
dataset,
shuffle=True,
num_replicas=world_size,
rank=rank,
seed=epoch,
)
else:
sampler = RandomSampler(dataset)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
sampler=sampler,
collate_fn=identity,
drop_last=True,
)
# make it infinite
dataloader_iterator = TemporalLengthBucketeer(
dataloader,
max_frames=max_frames,
epoch=epoch,
)
return dataloader_iterator
def create_mixed_dataloaders(
dataset, batch_size, num_workers, world_size=None, rank=None, epoch=0,
image_mix_ratio=0.1, use_image_video_mixed_training=True,
):
"""
The video & image mixed training dataloader builder
"""
assert world_size is not None
assert rank is not None
image_gpus = max(1, int(world_size * image_mix_ratio))
if use_image_video_mixed_training:
video_gpus = world_size - image_gpus
else:
# only use video data
video_gpus = world_size
image_gpus = 0
print(f"{image_gpus} gpus for image, {video_gpus} gpus for video")
if rank < video_gpus:
sampler = DistributedSampler(
dataset,
shuffle=True,
num_replicas=video_gpus,
rank=rank,
seed=epoch,
)
else:
sampler = DistributedSampler(
dataset,
shuffle=True,
num_replicas=image_gpus,
rank=rank - video_gpus,
seed=epoch,
)
loader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
sampler=sampler,
collate_fn=default_collate,
drop_last=True,
)
# To make it infinite
loader = IterLoader(loader, use_distributed=True, epoch=epoch)
return loader
import os
import json
import jsonlines
import torch
import math
import random
import cv2
from tqdm import tqdm
from collections import OrderedDict
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import numpy as np
import subprocess
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms import functional as F
class ImageTextDataset(Dataset):
"""
Usage:
The dataset class for image-text pairs, used for image generation training
It supports multi-aspect ratio training
params:
anno_file: The annotation file list
        add_normalize: whether to normalize the input image pixels to [-1, 1], default: True
        ratios: The aspect ratios used during training, format: width / height
        sizes: The resolution of training images, format: (width, height)
"""
def __init__(
self, anno_file, add_normalize=True,
ratios=[1/1, 3/5, 5/3],
sizes=[(1024, 1024), (768, 1280), (1280, 768)],
crop_mode='random', p_random_ratio=0.0,
):
# Ratios and Sizes : (w h)
super().__init__()
self.image_annos = []
if not isinstance(anno_file, list):
anno_file = [anno_file]
for anno_file_ in anno_file:
print(f"Load image annotation files from {anno_file_}")
with jsonlines.open(anno_file_, 'r') as reader:
for item in reader:
self.image_annos.append(item)
print(f"Totally Remained {len(self.image_annos)} images")
transform_list = [
transforms.ToTensor(),
]
if add_normalize:
transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
self.transform = transforms.Compose(transform_list)
print(f"Transform List is {transform_list}")
assert crop_mode in ['center', 'random']
self.crop_mode = crop_mode
self.ratios = ratios
self.sizes = sizes
self.p_random_ratio = p_random_ratio
def get_closest_size(self, x):
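        # Choose the training size whose aspect ratio (w/h) is closest to the input image's,
        # or a random one with probability `p_random_ratio`.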
if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
best_size_idx = np.random.randint(len(self.ratios))
else:
w, h = x.width, x.height
best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
return self.sizes[best_size_idx]
def get_resize_size(self, orig_size, tgt_size):
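        # Both sizes are (width, height). Returns the shorter-edge length passed to torchvision's
        # resize so that the resized image always fully covers the target crop size.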
if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
resize_size = max(alt_min, min(tgt_size))
else:
alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
resize_size = max(alt_max, max(tgt_size))
return resize_size
def __len__(self):
return len(self.image_annos)
def __getitem__(self, index):
image_anno = self.image_annos[index]
try:
img = Image.open(image_anno['image']).convert("RGB")
text = image_anno['text']
assert isinstance(text, str), "Text should be str"
size = self.get_closest_size(img)
resize_size = self.get_resize_size((img.width, img.height), size)
img = transforms.functional.resize(img, resize_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
if self.crop_mode == 'center':
img = transforms.functional.center_crop(img, (size[1], size[0]))
elif self.crop_mode == 'random':
img = transforms.RandomCrop((size[1], size[0]))(img)
else:
img = transforms.functional.center_crop(img, (size[1], size[0]))
image_tensor = self.transform(img)
return {
"video": image_tensor, # using keyname `video`, to be compatible with video
"text" : text,
"identifier": 'image',
}
except Exception as e:
print(f'Load Image Error with {e}')
return self.__getitem__(random.randint(0, self.__len__() - 1))
class LengthGroupedVideoTextDataset(Dataset):
"""
Usage:
The dataset class for video-text pairs, used for video generation training
It groups the video with the same frames together
Now only supporting fixed resolution during training
params:
anno_file: The annotation file list
        max_frames: The maximum temporal length (in VAE latent frames); e.g., 16 latent frames => (16 - 1) * 8 + 1 = 121 video frames
        load_vae_latent: Load the pre-extracted VAE latents during training; we recommend extracting the latents
            in advance to reduce the per-batch time cost
        load_text_fea: Load the pre-extracted text features during training; we recommend extracting the prompt text
            features in advance, since the T5 encoder consumes a lot of GPU memory
"""
def __init__(self, anno_file, max_frames=16, resolution='384p', load_vae_latent=True, load_text_fea=True):
super().__init__()
self.video_annos = []
self.max_frames = max_frames
self.load_vae_latent = load_vae_latent
self.load_text_fea = load_text_fea
self.resolution = resolution
        assert load_vae_latent, "Currently only loading pre-extracted VAE latents is supported; loading raw video frames directly will be supported in the future"
if not isinstance(anno_file, list):
anno_file = [anno_file]
for anno_file_ in anno_file:
with jsonlines.open(anno_file_, 'r') as reader:
for item in tqdm(reader):
self.video_annos.append(item)
print(f"Totally Remained {len(self.video_annos)} videos")
def __len__(self):
return len(self.video_annos)
def __getitem__(self, index):
try:
video_anno = self.video_annos[index]
text = video_anno['text']
latent_path = video_anno['latent']
latent = torch.load(latent_path, map_location='cpu') # loading the pre-extracted video latents
# TODO: remove the hard code latent shape checking
if self.resolution == '384p':
assert latent.shape[-1] == 640 // 8
assert latent.shape[-2] == 384 // 8
else:
assert self.resolution == '768p'
assert latent.shape[-1] == 1280 // 8
assert latent.shape[-2] == 768 // 8
cur_temp = latent.shape[2]
cur_temp = min(cur_temp, self.max_frames)
video_latent = latent[:,:,:cur_temp].float()
assert video_latent.shape[1] == 16
if self.load_text_fea:
text_fea_path = video_anno['text_fea']
text_fea = torch.load(text_fea_path, map_location='cpu')
return {
'video': video_latent,
'prompt_embed': text_fea['prompt_embed'],
'prompt_attention_mask': text_fea['prompt_attention_mask'],
'pooled_prompt_embed': text_fea['pooled_prompt_embed'],
"identifier": 'video',
}
else:
return {
'video': video_latent,
'text': text,
"identifier": 'video',
}
except Exception as e:
print(f'Load Video Error with {e}')
return self.__getitem__(random.randint(0, self.__len__() - 1))
class VideoFrameProcessor:
# load a video and transform
def __init__(self, resolution=256, num_frames=24, add_normalize=True, sample_fps=24):
image_size = resolution
transform_list = [
transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
transforms.CenterCrop(image_size),
]
if add_normalize:
transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
print(f"Transform List is {transform_list}")
self.num_frames = num_frames
self.transform = transforms.Compose(transform_list)
self.sample_fps = sample_fps
def __call__(self, video_path):
try:
video_capture = cv2.VideoCapture(video_path)
fps = video_capture.get(cv2.CAP_PROP_FPS)
frames = []
while True:
flag, frame = video_capture.read()
if not flag:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = torch.from_numpy(frame)
frame = frame.permute(2, 0, 1)
frames.append(frame)
video_capture.release()
sample_fps = self.sample_fps
interval = max(int(fps / sample_fps), 1)
frames = frames[::interval]
if len(frames) < self.num_frames:
num_frame_to_pack = self.num_frames - len(frames)
recurrent_num = num_frame_to_pack // len(frames)
frames = frames + recurrent_num * frames + frames[:(num_frame_to_pack % len(frames))]
assert len(frames) >= self.num_frames, f'{len(frames)}'
start_indexs = list(range(0, max(0, len(frames) - self.num_frames + 1)))
start_index = random.choice(start_indexs)
filtered_frames = frames[start_index : start_index+self.num_frames]
            assert len(filtered_frames) == self.num_frames, f"The number of sampled frames should equal {self.num_frames}"
filtered_frames = torch.stack(filtered_frames).float() / 255
filtered_frames = self.transform(filtered_frames)
filtered_frames = filtered_frames.permute(1, 0, 2, 3)
return filtered_frames, None
except Exception as e:
print(f"Load video: {video_path} Error, Exception {e}")
return None, None
class VideoDataset(Dataset):
def __init__(self, anno_file, resolution=256, max_frames=6, add_normalize=True):
super().__init__()
self.video_annos = []
self.max_frames = max_frames
if not isinstance(anno_file, list):
anno_file = [anno_file]
print(f"The training video clip frame number is {max_frames} ")
for anno_file_ in anno_file:
print(f"Load annotation file from {anno_file_}")
with jsonlines.open(anno_file_, 'r') as reader:
for item in tqdm(reader):
self.video_annos.append(item)
print(f"Totally Remained {len(self.video_annos)} videos")
self.video_processor = VideoFrameProcessor(resolution, max_frames, add_normalize)
def __len__(self):
return len(self.video_annos)
def __getitem__(self, index):
video_anno = self.video_annos[index]
video_path = video_anno['video']
try:
video_tensors, video_frames = self.video_processor(video_path)
assert video_tensors.shape[1] == self.max_frames
return {
"video": video_tensors,
"identifier": 'video',
}
except Exception as e:
            print(f'Load Video Error with {e}')
return self.__getitem__(random.randint(0, self.__len__() - 1))
class ImageDataset(Dataset):
def __init__(self, anno_file, resolution=256, max_frames=8, add_normalize=True):
super().__init__()
self.image_annos = []
self.max_frames = max_frames
image_paths = []
if not isinstance(anno_file, list):
anno_file = [anno_file]
for anno_file_ in anno_file:
print(f"Load annotation file from {anno_file_}")
with jsonlines.open(anno_file_, 'r') as reader:
for item in tqdm(reader):
image_paths.append(item['image'])
print(f"Totally Remained {len(image_paths)} images")
# pack multiple frames
for idx in range(0, len(image_paths), self.max_frames):
image_path_shard = image_paths[idx : idx + self.max_frames]
if len(image_path_shard) < self.max_frames:
image_path_shard = image_path_shard + image_paths[:self.max_frames - len(image_path_shard)]
assert len(image_path_shard) == self.max_frames
self.image_annos.append(image_path_shard)
image_size = resolution
transform_list = [
transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
]
if add_normalize:
transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
print(f"Transform List is {transform_list}")
self.transform = transforms.Compose(transform_list)
def __len__(self):
return len(self.image_annos)
def __getitem__(self, index):
image_paths = self.image_annos[index]
try:
packed_pil_frames = [Image.open(image_path).convert("RGB") for image_path in image_paths]
filtered_frames = [self.transform(frame) for frame in packed_pil_frames]
filtered_frames = torch.stack(filtered_frames) # [t, c, h, w]
filtered_frames = filtered_frames.permute(1, 0, 2, 3) # [c, t, h, w]
return {
"video": filtered_frames,
"identifier": 'image',
}
except Exception as e:
print(f'Load Images Error with {e}')
return self.__getitem__(random.randint(0, self.__len__() - 1))
from .scheduling_cosine_ddpm import DDPMCosineScheduler
from .scheduling_flow_matching import PyramidFlowMatchEulerDiscreteScheduler
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin
@dataclass
class DDPMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.Tensor
class DDPMCosineScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
scaler: float = 1.0,
s: float = 0.008,
):
self.scaler = scaler
self.s = torch.tensor([s])
self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
def _alpha_cumprod(self, t, device):
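        # Cosine noise schedule: alpha_bar(t) = cos^2(((t + s) / (1 + s)) * pi / 2), normalized by its
        # value at t = 0, optionally warped by `scaler`, and clamped for numerical stability.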
if self.scaler > 1:
t = 1 - (1 - t) ** self.scaler
elif self.scaler < 1:
t = t**self.scaler
alpha_cumprod = torch.cos(
(t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5
) ** 2 / self._init_alpha_cumprod.to(device)
return alpha_cumprod.clamp(0.0001, 0.9999)
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.Tensor`: scaled input sample
"""
return sample
def set_timesteps(
self,
num_inference_steps: int = None,
timesteps: Optional[List[int]] = None,
device: Union[str, torch.device] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model. If passed,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps are moved.
"""
if timesteps is None:
timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device)
if not isinstance(timesteps, torch.Tensor):
timesteps = torch.Tensor(timesteps).to(device)
self.timesteps = timesteps
def step(
self,
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
dtype = model_output.dtype
device = model_output.device
t = timestep
prev_t = self.previous_timestep(t)
alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]])
alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
alpha = alpha_cumprod / alpha_cumprod_prev
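        # DDPM posterior mean under epsilon prediction: mu = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)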
mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt())
std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise
pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
if not return_dict:
return (pred.to(dtype),)
return DDPMSchedulerOutput(prev_sample=pred.to(dtype))
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
device = original_samples.device
dtype = original_samples.dtype
alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
)
noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
return noisy_samples.to(dtype=dtype)
def __len__(self):
return self.config.num_train_timesteps
def previous_timestep(self, timestep):
index = (self.timesteps - timestep[0]).abs().argmin().item()
prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0])
return prev_t
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List
import math
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class PyramidFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0, # Following Stable diffusion 3,
stages: int = 3,
stage_range: List = [0, 1/3, 2/3, 1],
gamma: float = 1/3,
):
self.timestep_ratios = {} # The timestep ratio for each stage
self.timesteps_per_stage = {} # The detailed timesteps per stage
self.sigmas_per_stage = {}
self.start_sigmas = {}
self.end_sigmas = {}
self.ori_start_sigmas = {}
# self.init_sigmas()
self.init_sigmas_for_each_stage()
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
self.gamma = gamma
def init_sigmas(self):
"""
initialize the global timesteps and sigmas
"""
num_train_timesteps = self.config.num_train_timesteps
shift = self.config.shift
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
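        # SD3-style timestep shifting: for shift > 1 the sigmas are pushed toward 1, spending more of
        # the schedule in the high-noise region.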
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.timesteps = sigmas * num_train_timesteps
self._step_index = None
self._begin_index = None
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
def init_sigmas_for_each_stage(self):
"""
Init the timesteps for each stage
"""
self.init_sigmas()
stage_distance = []
stages = self.config.stages
training_steps = self.config.num_train_timesteps
stage_range = self.config.stage_range
# Init the start and end point of each stage
for i_s in range(stages):
# To decide the start and ends point
start_indice = int(stage_range[i_s] * training_steps)
start_indice = max(start_indice, 0)
end_indice = int(stage_range[i_s+1] * training_steps)
end_indice = min(end_indice, training_steps)
start_sigma = self.sigmas[start_indice].item()
end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
self.ori_start_sigmas[i_s] = start_sigma
if i_s != 0:
ori_sigma = 1 - start_sigma
gamma = self.config.gamma
corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
# corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
start_sigma = 1 - corrected_sigma
stage_distance.append(start_sigma - end_sigma)
self.start_sigmas[i_s] = start_sigma
self.end_sigmas[i_s] = end_sigma
# Determine the ratio of each stage according to flow length
tot_distance = sum(stage_distance)
for i_s in range(stages):
if i_s == 0:
start_ratio = 0.0
else:
start_ratio = sum(stage_distance[:i_s]) / tot_distance
if i_s == stages - 1:
end_ratio = 1.0
else:
end_ratio = sum(stage_distance[:i_s+1]) / tot_distance
self.timestep_ratios[i_s] = (start_ratio, end_ratio)
# Determine the timesteps and sigmas for each stage
for i_s in range(stages):
timestep_ratio = self.timestep_ratios[i_s]
timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
timesteps = np.linspace(
timestep_max, timestep_min, training_steps + 1,
)
self.timesteps_per_stage[i_s] = timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
stage_sigmas = np.linspace(
1, 0, training_steps + 1,
)
self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None):
"""
Setting the timesteps and sigmas for each stage
"""
self.num_inference_steps = num_inference_steps
training_steps = self.config.num_train_timesteps
self.init_sigmas()
stage_timesteps = self.timesteps_per_stage[stage_index]
timestep_max = stage_timesteps[0].item()
timestep_min = stage_timesteps[-1].item()
timesteps = np.linspace(
timestep_max, timestep_min, num_inference_steps,
)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
stage_sigmas = self.sigmas_per_stage[stage_index]
sigma_max = stage_sigmas[0].item()
sigma_min = stage_sigmas[-1].item()
ratios = np.linspace(
sigma_max, sigma_min, num_inference_steps
)
sigmas = torch.from_numpy(ratios).to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
tuple.
Returns:
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._step_index = 0
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
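        # First-order Euler update of the flow ODE, treating the model output as the velocity in sigma.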
prev_sample = sample + (sigma_next - sigma) * model_output
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
# Pyramid Flow's DiT Finetuning Guide
This is the finetuning guide for the DiT in Pyramid Flow. We provide instructions for both the autoregressive and non-autoregressive versions. The former is more research-oriented and the latter is more stable (but less efficient without the temporal pyramid). Please refer to [another document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/VAE) for VAE finetuning.
## Hardware Requirements
+ DiT finetuning: At least 8 A100 GPUs.
## Prepare the Dataset
The training dataset should be arranged into a JSON file with `video` and `text` fields. Since video VAE latent extraction is very slow, we strongly recommend pre-extracting the video VAE latents to save training time. We provide a video VAE latent extraction script in the `tools` folder. You can run it with the following command:
```bash
sh scripts/extract_vae_latent.sh
```
(Optional) Since the T5 text encoder consumes a lot of GPU memory, pre-extracting the text features will save training memory. We also provide a text feature extraction script in the `tools` folder. You can run it with the following command:
```bash
sh scripts/extract_text_feature.sh
```
Each line of the final training annotation file should look like the following:
```
{"video": video_path, "text": text prompt, "latent": extracted video vae latent, "text_fea": extracted text feature}
```
We provide example JSONL annotation files for [video](https://github.com/jy0205/Pyramid-Flow/blob/main/annotation/video_text.jsonl) and [image](https://github.com/jy0205/Pyramid-Flow/blob/main/annotation/image_text.jsonl) training in the `annotation` folder. You can refer to them when preparing your own training dataset.
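For concreteness, here is a minimal sketch of how one annotation line can be produced. All paths below are placeholders (not files shipped with the repo); in practice `latent` and `text_fea` point to the outputs of the extraction scripts above.

```python
import json

# Placeholder entry; replace the paths with your own data and the outputs of
# scripts/extract_vae_latent.sh and scripts/extract_text_feature.sh.
entry = {
    "video": "/data/videos/clip_0001.mp4",
    "text": "a cat on the moon, cinematic style",
    "latent": "/data/latents/clip_0001.pt",          # pre-extracted video VAE latent
    "text_fea": "/data/text_fea/clip_0001-text.pt",  # pre-extracted T5 text feature
}

with open("video_text.jsonl", "a") as f:
    f.write(json.dumps(entry, ensure_ascii=False) + "\n")
```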
## Run Training
We provide two types of training scripts: (1) autoregressive video generation training with the temporal pyramid; (2) full-sequence diffusion training with pyramid flow for both text-to-image and text-to-video. These correspond to the following two script files. Run these training scripts with at least 8 GPUs:
+ `scripts/train_pyramid_flow.sh`: The autoregressive video generation training with temporal pyramid.
```bash
sh scripts/train_pyramid_flow.sh
```
+ `scripts/train_pyramid_flow_without_ar.sh`: Using pyramid-flow for full-sequence diffusion training.
```bash
sh scripts/train_pyramid_flow_without_ar.sh
```
## Tips
+ For the 768p version, make sure to add the args: `--gradient_checkpointing`
+ Param `NUM_FRAMES` should be set to a multiple of 8
+ The param `video_sync_group` indicates the number of processes that receive the same input video; it is used for temporal pyramid AR training. We recommend setting this value to 4, 8, or 16 (16 is better if you have more GPUs)
+ Make sure to set `NUM_FRAMES % VIDEO_SYNC_GROUP == 0`, `GPUS % VIDEO_SYNC_GROUP == 0`, and `BATCH_SIZE % 4 == 0`; a quick sanity check is sketched below
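The divisibility constraints above are easy to get wrong, so a small hypothetical helper like the following (not part of the repo's scripts) can be used to validate a launch configuration before training:

```python
def check_ar_training_config(num_frames: int, gpus: int, video_sync_group: int, batch_size: int) -> None:
    # Mirrors the tips above; raises AssertionError on an invalid configuration.
    assert num_frames % 8 == 0, "NUM_FRAMES should be a multiple of 8"
    assert num_frames % video_sync_group == 0, "NUM_FRAMES must be divisible by VIDEO_SYNC_GROUP"
    assert gpus % video_sync_group == 0, "GPUS must be divisible by VIDEO_SYNC_GROUP"
    assert batch_size % 4 == 0, "BATCH_SIZE must be a multiple of 4"

check_ar_training_config(num_frames=16, gpus=8, video_sync_group=4, batch_size=4)
```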
# Pyramid Flow's VAE Training Guide
This is the training guide for a [MAGVIT-v2](https://arxiv.org/abs/2310.05737)-like continuous 3D VAE, which should be quite flexible. Feel free to build your own video generative model on top of this VAE training code. Please refer to [another document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/DiT) for DiT finetuning.
## Hardware Requirements
+ VAE training: At least 8 A100 GPUs.
## Prepare the Dataset
Training our causal video VAE uses both image and video data. Both should be arranged into an annotation file with a `video` or `image` field, one JSON object per line, in the following format:
```
# For Video
{"video": video_path}
# For Image
{"image": image_path}
```
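As a minimal sketch (the directory layout and file patterns are assumptions, not part of the repo's API), such an annotation file can be generated by globbing a data folder:

```python
import json
from pathlib import Path

def write_vae_annotation(data_root: str, data_type: str, save_path: str) -> None:
    # data_type is "video" or "image"; each output line holds one absolute path.
    patterns = {"video": ["*.mp4"], "image": ["*.jpg", "*.jpeg", "*.png"]}
    with open(save_path, "w") as f:
        for pattern in patterns[data_type]:
            for p in sorted(Path(data_root).glob(pattern)):
                f.write(json.dumps({data_type: str(p.resolve())}, ensure_ascii=False) + "\n")

write_vae_annotation("/data/videos", "video", "vae_video_annotation.jsonl")
```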
## Run Training
The causal video VAE is trained in two stages:
+ Stage-1: mixed image and video training
+ Stage-2: pure video training, using context parallelism to load videos with more frames
The VAE training script is `scripts/train_causal_video_vae.sh`, run it as follows:
```bash
sh scripts/train_causal_video_vae.sh
```
We also provide a VAE demo `causal_video_vae_demo.ipynb` for image and video reconstruction.
## Tips
+ For stage-1, we use mixed image and video training. Add the param `--use_image_video_mixed_training` to enable it. We set the image ratio to 0.1 by default.
+ Setting the `resolution` to 256 is enough for VAE training.
+ For stage-1, `max_frames` is set to 17, meaning 17 sampled video frames are used per training clip.
+ For stage-2, enable the param `use_context_parallel` to distribute the frames of long videos across multiple GPUs. Make sure to set `GPUS % CONTEXT_SIZE == 0` and `NUM_FRAMES = 17 * CONTEXT_SIZE + 1` (e.g., 35 frames when `CONTEXT_SIZE` is 2).
import os
import torch
def extract_vae_weights(vae_checkpoint_path: str,
save_path: str):
checkpoint = torch.load(vae_checkpoint_path)
weights = checkpoint['model']
new_weights = {}
for name, params in weights.items():
if "dis" in name or "prece" in name or "logvar" in name: continue
name = name.split(".", 1)[1]
new_weights[name] = params
torch.save(new_weights, save_path)
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--vae_checkpoint_path", type=str)
parser.add_argument("--save_path", type=str)
args = parser.parse_args()
extract_vae_weights(args.vae_checkpoint_path,
args.save_path)
import os
import json
from pathlib import Path
def get_mapping(path_list):
return {p.stem: str(p.resolve()) for p in path_list}
def generate_image_text(image_root: str,
prompt_root: str,
save_root: str):
image_root = Path(image_root)
prompt_root = Path(prompt_root)
image_path_list = [*image_root.glob("*.jpg"), *image_root.glob("*.png"), *image_root.glob("*.JPEG")]
prompt_path_list = [*prompt_root.glob("*.json")]
image_path_mapping = get_mapping(image_path_list)
prompt_path_mapping = get_mapping(prompt_path_list)
keys = set(image_path_mapping.keys()) & set(prompt_path_mapping.keys())
for key in keys:
with open(prompt_path_mapping[key], "r") as f:
text = json.loads(f.read().strip())['prompt']
tmp = {"image": image_path_mapping[key], "text": text}
with open(os.path.join(save_root, "image_text.jsonl"), "a") as f:
f.write(json.dumps(tmp, ensure_ascii=False) + '\n')
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--image_root", type=str)
parser.add_argument("--prompt_root", type=str)
parser.add_argument("--save_root", type=str)
args = parser.parse_args()
generate_image_text(args.image_root, args.prompt_root, args.save_root)
from pathlib import Path
from argparse import ArgumentParser
import json
from typing import Union
def generate_vae_annotation(data_root: str,
data_type: str,
save_path: str):
assert data_type in ['image', 'video']
data_root = Path(data_root)
if data_type == "video":
data_path_list = [*data_root.glob("*.mp4")]
elif data_type == "image":
data_path_list = [*data_root.glob("*.png"), *data_root.glob("*.jpeg"),
*data_root.glob("*.jpg"), *data_root.glob("*.JPEG")]
else:
        raise NotImplementedError
with open(save_path, "w") as f:
for data_path in data_path_list:
f.write(json.dumps({data_type: str(data_path.resolve())}, ensure_ascii=False) + '\n')
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--data_root", type=str)
parser.add_argument("--data_type", type=str)
parser.add_argument("--save_path", type=str)
args = parser.parse_args()
generate_vae_annotation(args.data_root,
args.data_type,
args.save_path)
# Extract the captions for the given VidGen-1M videos and set the save locations for the corresponding features
import os
import json
from pathlib import Path
from argparse import ArgumentParser
def get_video_text(video_root,
caption_json_path,
save_root,
video_latent_root,
text_fea_root):
video_root = Path(video_root)
    # Support up to three levels of nested directories
video_path_list = [*video_root.glob("*.mp4"), *video_root.glob("*/*.mp4"), *video_root.glob("*/*/*.mp4")]
vid_path = {p.stem: str(p.resolve()) for p in video_path_list}
with open(caption_json_path, "r") as f:
captions = json.load(f)
vid_caption = {}
for d in captions:
vid_caption[d['vid']] = d['caption']
os.makedirs(video_latent_root, exist_ok=True)
os.makedirs(text_fea_root, exist_ok=True)
with open(os.path.join(save_root, "video_text.jsonl"), "w") as f:
for vid, vpath in vid_path.items():
text = vid_caption[vid]
latent_path = str(Path(os.path.join(video_latent_root, f"{vid}.pt")).resolve())
text_fea_path = str(Path(os.path.join(text_fea_root, f"{vid}-text.pt")).resolve())
f.write(json.dumps({"video": vpath, "text": text, "latent": latent_path, "text_fea": text_fea_path}, ensure_ascii=False) + '\n')
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--video_root", type=str)
parser.add_argument("-cjp", "--caption_json_path", type=str)
parser.add_argument("-sr", "--save_root", type=str)
parser.add_argument("-vlr", "--video_latent_root", type=str)
parser.add_argument("-tfr", "--text_fea_root", type=str)
args = parser.parse_args()
get_video_text(args.video_root,
args.caption_json_path,
args.save_root,
args.video_latent_root,
args.text_fea_root)
icon.png (68.4 KB)
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import torch\n",
"import numpy as np\n",
"import PIL\n",
"from PIL import Image\n",
"from IPython.display import HTML\n",
"from pyramid_dit import PyramidDiTForVideoGeneration\n",
"from IPython.display import Image as ipython_image\n",
"from diffusers.utils import load_image, export_to_video, export_to_gif"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"variant='diffusion_transformer_image' # For low resolution\n",
"model_name = \"pyramid_flux\"\n",
"\n",
"model_path = \"/home/jinyang06/models/pyramid-flow-miniflux\" # The downloaded checkpoint dir\n",
"model_dtype = 'bf16'\n",
"\n",
"device_id = 0\n",
"torch.cuda.set_device(device_id)\n",
"\n",
"model = PyramidDiTForVideoGeneration(\n",
" model_path,\n",
" model_dtype,\n",
" model_name=model_name,\n",
" model_variant=variant,\n",
")\n",
"\n",
"model.vae.to(\"cuda\")\n",
"model.dit.to(\"cuda\")\n",
"model.text_encoder.to(\"cuda\")\n",
"\n",
"model.vae.enable_tiling()\n",
"\n",
"if model_dtype == \"bf16\":\n",
" torch_dtype = torch.bfloat16 \n",
"elif model_dtype == \"fp16\":\n",
" torch_dtype = torch.float16\n",
"else:\n",
" torch_dtype = torch.float32"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Text-to-Image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"shoulder and full head portrait of a beautiful 19 year old girl, brunette, smiling, stunning, highly detailed, glamour lighting, HDR, photorealistic, hyperrealism, octane render, unreal engine\"\n",
"\n",
"# now support 3 aspect ratios\n",
"resolution_dict = {\n",
" '1:1' : (1024, 1024),\n",
" '5:3' : (1280, 768),\n",
" '3:5' : (768, 1280),\n",
"}\n",
"\n",
"ratio = '1:1' # 1:1, 5:3, 3:5\n",
"\n",
"width, height = resolution_dict[ratio]\n",
"\n",
"\n",
"with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
" images = model.generate(\n",
" prompt=prompt,\n",
" num_inference_steps=[20, 20, 20],\n",
" height=height,\n",
" width=width,\n",
" temp=1,\n",
" guidance_scale=9.0, \n",
" output_type=\"pil\",\n",
" save_memory=False, \n",
" )\n",
"\n",
"display(images[0])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
import os
import torch
import sys
import argparse
import random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from diffusers.utils import export_to_video
from pyramid_dit import PyramidDiTForVideoGeneration
from trainer_misc import init_distributed_mode, init_sequence_parallel_group
import PIL
from PIL import Image
def get_args():
parser = argparse.ArgumentParser('Pytorch Multi-process Script', add_help=False)
parser.add_argument('--model_name', default='pyramid_mmdit', type=str, help="The model name", choices=["pyramid_flux", "pyramid_mmdit"])
parser.add_argument('--model_dtype', default='bf16', type=str, help="The Model Dtype: bf16")
parser.add_argument('--model_path', default='/home/jinyang06/models/pyramid-flow', type=str, help='Set it to the downloaded checkpoint dir')
parser.add_argument('--variant', default='diffusion_transformer_768p', type=str,)
parser.add_argument('--task', default='t2v', type=str, choices=['i2v', 't2v'])
parser.add_argument('--temp', default=16, type=int, help='The generated latent num, num_frames = temp * 8 + 1')
parser.add_argument('--sp_group_size', default=2, type=int, help="The number of gpus used for inference, should be 2 or 4")
    parser.add_argument('--sp_proc_num', default=-1, type=int, help="The number of processes used for video training; default=-1 means using all processes.")
return parser.parse_args()
def main():
args = get_args()
# setup DDP
init_distributed_mode(args)
assert args.world_size == args.sp_group_size, "The sequence parallel size should be DDP world size"
# Enable sequence parallel
init_sequence_parallel_group(args)
device = torch.device('cuda')
rank = args.rank
model_dtype = args.model_dtype
if args.model_name == "pyramid_flux":
assert args.variant != "diffusion_transformer_768p", "The pyramid_flux does not support high resolution now, \
we will release it after finishing training. You can modify the model_name to pyramid_mmdit to support 768p version generation"
model = PyramidDiTForVideoGeneration(
args.model_path,
model_dtype,
model_name=args.model_name,
model_variant=args.variant,
)
model.vae.to(device)
model.dit.to(device)
model.text_encoder.to(device)
model.vae.enable_tiling()
if model_dtype == "bf16":
torch_dtype = torch.bfloat16
elif model_dtype == "fp16":
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
# The video generation config
if args.variant == 'diffusion_transformer_768p':
width = 1280
height = 768
else:
assert args.variant == 'diffusion_transformer_384p'
width = 640
height = 384
if args.task == 't2v':
# prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"
prompt = "a cat on the moon, salt desert, cinematic style, shot on 35mm film, vivid colors"
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):
frames = model.generate(
prompt=prompt,
num_inference_steps=[20, 20, 20],
video_num_inference_steps=[10, 10, 10],
height=height,
width=width,
temp=args.temp,
guidance_scale=7.0, # The guidance for the first frame, set it to 7 for 384p variant
video_guidance_scale=5.0, # The guidance for the other video latent
output_type="pil",
save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
cpu_offloading=False, # If OOM, set it to True to reduce memory usage
inference_multigpu=True,
)
if rank == 0:
export_to_video(frames, "./text_to_video_sample.mp4", fps=24)
else:
assert args.task == 'i2v'
image_path = 'assets/the_great_wall.jpg'
image = Image.open(image_path).convert("RGB")
image = image.resize((width, height))
prompt = "FPV flying over the Great Wall"
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):
frames = model.generate_i2v(
prompt=prompt,
input_image=image,
num_inference_steps=[10, 10, 10],
temp=args.temp,
video_guidance_scale=4.0,
output_type="pil",
save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
cpu_offloading=False, # If OOM, set it to True to reduce memory usage
inference_multigpu=True,
)
if rank == 0:
export_to_video(frames, "./image_to_video_sample.mp4", fps=24)
torch.distributed.barrier()
if __name__ == "__main__":
main()
# Unique model identifier
modelCode=1126
# Model name
modelName=pyramid-flow_pytorch
# Model description
modelDescription=Fast, high-quality video generation
# Application scenarios
appScenario=training, inference, AIGC, e-commerce, education, broadcast media
# Framework type
frameType=Pytorch
from .pyramid_dit_for_video_gen_pipeline import PyramidDiTForVideoGeneration
from .flux_modules import FluxSingleTransformerBlock, FluxTransformerBlock, FluxTextEncoderWithMask
from .mmdit_modules import JointTransformerBlock, SD3TextEncoderWithMask
from .modeling_pyramid_flux import PyramidFluxTransformer
from .modeling_text_encoder import FluxTextEncoderWithMask
from .modeling_flux_block import FluxSingleTransformerBlock, FluxTransformerBlock
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.models.activations import get_activation, FP32SiLU
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
Args
timesteps (torch.Tensor):
a 1-D Tensor of N indices, one per batch element. These may be fractional.
embedding_dim (int):
the dimension of the output.
flip_sin_to_cos (bool):
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
downscale_freq_shift (float):
Controls the delta between frequencies between dimensions
scale (float):
Scaling factor applied to the embeddings.
max_period (int):
Controls the maximum frequency of the embeddings
Returns
torch.Tensor: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
return t_emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = get_activation(act_fn)
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
if post_act_fn is None:
self.post_act = None
else:
self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
elif act_fn == "silu_fp32":
self.act_1 = FP32SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(self, timestep, guidance, pooled_projection):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
time_guidance_emb = timesteps_emb + guidance_emb
pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections
return conditioning
class CombinedTimestepTextProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(self, timestep, pooled_projection):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
return conditioning
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import inspect
from einops import rearrange
from diffusers.utils import deprecate
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, SwiGLU
from .modeling_normalization import (
AdaLayerNormContinuous, AdaLayerNormZero,
AdaLayerNormZeroSingle, FP32LayerNorm, RMSNorm
)
from trainer_misc import (
is_sequence_parallel_initialized,
get_sequence_parallel_group,
get_sequence_parallel_world_size,
all_to_all,
)
try:
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except:
flash_attn_func = None
flash_attn_qkvpacked_func = None
flash_attn_varlen_func = None
def apply_rope(xq, xk, freqs_cis):
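    # Rotary position embedding: view the last dim of q/k as (real, imag) pairs and rotate them
    # by the precomputed cos/sin factors in `freqs_cis`.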
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim, bias=bias)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
elif activation_fn == "swiglu":
act_fn = SwiGLU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
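# Joint text-video attention for sequence-parallel training: the sharded hidden sequence is gathered via
# all_to_all, per-stage tokens are concatenated with the T5 text tokens, padding is removed using the
# precomputed mask indices, and attention runs with flash_attn_varlen_func before results are scattered back.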
class SequenceParallelVarlenFlashSelfAttentionWithT5Mask:
def __init__(self):
pass
def __call__(
self, query, key, value, encoder_query, encoder_key, encoder_value,
heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
):
assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
batch_size = query.shape[0]
qkv_list = []
num_stages = len(hidden_length)
encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
# To sync the encoder query, key and values
sp_group = get_sequence_parallel_group()
sp_group_size = get_sequence_parallel_world_size()
encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
output_hidden = torch.zeros_like(qkv[:,:,0])
output_encoder_hidden = torch.zeros_like(encoder_qkv[:,:,0])
encoder_length = encoder_qkv.shape[1]
i_sum = 0
for i_p, length in enumerate(hidden_length):
# get the query, key, value from padding sequence
encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
qkv_tokens = qkv[:, i_sum:i_sum+length]
qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, pad_seq, 3, nhead, dim]
if image_rotary_emb is not None:
concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
indices = encoder_attention_mask[i_p]['indices']
qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
i_sum += length
token_lengths = [x_.shape[0] for x_ in qkv_list]
qkv = torch.cat(qkv_list, dim=0)
query, key, value = qkv.unbind(1)
cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
max_seqlen_q = cu_seqlens.max().item()
max_seqlen_k = max_seqlen_q
cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_k = cu_seqlens_q.clone()
output = flash_attn_varlen_func(
query,
key,
value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
causal=False,
softmax_scale=scale,
)
# To merge the tokens
        i_sum = 0
        token_sum = 0
for i_p, length in enumerate(hidden_length):
tot_token_num = token_lengths[i_p]
stage_output = output[token_sum : token_sum + tot_token_num]
stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length * sp_group_size)
stage_encoder_hidden_output = stage_output[:, :encoder_length]
stage_hidden_output = stage_output[:, encoder_length:]
stage_hidden_output = all_to_all(stage_hidden_output, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
token_sum += tot_token_num
i_sum += length
output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
output_hidden = output_hidden.flatten(2, 3)
output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
return output_hidden, output_encoder_hidden
class VarlenFlashSelfAttentionWithT5Mask:
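    # Same joint text-video varlen FlashAttention as above, but without sequence parallelism:
    # each stage's text tokens are prepended to its video tokens, padding is removed via the
    # encoder attention mask, and a single flash_attn_varlen_func call covers all stages.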
def __init__(self):
pass
def __call__(
self, query, key, value, encoder_query, encoder_key, encoder_value,
heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
):
        assert encoder_attention_mask is not None, "The encoder attention mask needs to be set"
batch_size = query.shape[0]
output_hidden = torch.zeros_like(query)
output_encoder_hidden = torch.zeros_like(encoder_query)
encoder_length = encoder_query.shape[1]
qkv_list = []
num_stages = len(hidden_length)
encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
i_sum = 0
for i_p, length in enumerate(hidden_length):
encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
qkv_tokens = qkv[:, i_sum:i_sum+length]
concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
if image_rotary_emb is not None:
concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
indices = encoder_attention_mask[i_p]['indices']
qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
i_sum += length
token_lengths = [x_.shape[0] for x_ in qkv_list]
qkv = torch.cat(qkv_list, dim=0)
query, key, value = qkv.unbind(1)
cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
max_seqlen_q = cu_seqlens.max().item()
max_seqlen_k = max_seqlen_q
cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_k = cu_seqlens_q.clone()
output = flash_attn_varlen_func(
query,
key,
value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
causal=False,
softmax_scale=scale,
)
# To merge the tokens
        i_sum = 0
        token_sum = 0
for i_p, length in enumerate(hidden_length):
tot_token_num = token_lengths[i_p]
stage_output = output[token_sum : token_sum + tot_token_num]
stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length)
stage_encoder_hidden_output = stage_output[:, :encoder_length]
stage_hidden_output = stage_output[:, encoder_length:]
output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
token_sum += tot_token_num
i_sum += length
output_hidden = output_hidden.flatten(2, 3)
output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
return output_hidden, output_encoder_hidden
class SequenceParallelVarlenSelfAttentionWithT5Mask:
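    # Sequence-parallel joint text-video attention using torch scaled_dot_product_attention with an
    # explicit per-stage attention mask, as the fallback to the varlen FlashAttention path above.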
def __init__(self):
pass
def __call__(
self, query, key, value, encoder_query, encoder_key, encoder_value,
heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
):
        assert attention_mask is not None, "The attention mask needs to be set"
num_stages = len(hidden_length)
encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
# To sync the encoder query, key and values
sp_group = get_sequence_parallel_group()
sp_group_size = get_sequence_parallel_world_size()
encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
encoder_length = encoder_qkv.shape[1]
i_sum = 0
output_encoder_hidden_list = []
output_hidden_list = []
for i_p, length in enumerate(hidden_length):
encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
qkv_tokens = qkv[:, i_sum:i_sum+length]
qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
if image_rotary_emb is not None:
concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
query, key, value = concat_qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
stage_hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
)
stage_hidden_states = stage_hidden_states.transpose(1, 2) # [bs, tot_seq, nhead, dim]
output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
output_hidden = stage_hidden_states[:, encoder_length:]
output_hidden = all_to_all(output_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
output_hidden_list.append(output_hidden)
i_sum += length
output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1) # [b n s nhead d]
output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s h d -> (b n) s h d')
output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
output_hidden = torch.cat(output_hidden_list, dim=1).flatten(2, 3)
return output_hidden, output_encoder_hidden
class VarlenSelfAttentionWithT5Mask:
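    # Single-rank joint text-video attention using scaled_dot_product_attention with a per-stage attention mask.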
def __init__(self):
pass
def __call__(
self, query, key, value, encoder_query, encoder_key, encoder_value,
heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
):
        assert attention_mask is not None, "The attention mask needs to be set"
encoder_length = encoder_query.shape[1]
num_stages = len(hidden_length)
encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
i_sum = 0
output_encoder_hidden_list = []
output_hidden_list = []
for i_p, length in enumerate(hidden_length):
encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
qkv_tokens = qkv[:, i_sum:i_sum+length]
concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
if image_rotary_emb is not None:
concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
query, key, value = concat_qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
stage_hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
)
stage_hidden_states = stage_hidden_states.transpose(1, 2).flatten(2, 3) # [bs, tot_seq, dim]
output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
output_hidden_list.append(stage_hidden_states[:, encoder_length:])
i_sum += length
output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1) # [b n s d]
output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s d -> (b n) s d')
output_hidden = torch.cat(output_hidden_list, dim=1)
return output_hidden, output_encoder_hidden
class SequenceParallelVarlenFlashAttnSingle:
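    # Sequence-parallel varlen FlashAttention over a single packed hidden sequence
    # (no separate encoder query/key/value), used by the single-stream attention processor.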
def __init__(self):
pass
def __call__(
self, query, key, value, heads, scale,
hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
):
        assert encoder_attention_mask is not None, "The encoder attention mask needs to be set"
batch_size = query.shape[0]
qkv_list = []
num_stages = len(hidden_length)
qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
output_hidden = torch.zeros_like(qkv[:,:,0])
sp_group = get_sequence_parallel_group()
sp_group_size = get_sequence_parallel_world_size()
i_sum = 0
for i_p, length in enumerate(hidden_length):
# get the query, key, value from padding sequence
qkv_tokens = qkv[:, i_sum:i_sum+length]
qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
if image_rotary_emb is not None:
qkv_tokens[:,:,0], qkv_tokens[:,:,1] = apply_rope(qkv_tokens[:,:,0], qkv_tokens[:,:,1], image_rotary_emb[i_p])
indices = encoder_attention_mask[i_p]['indices']
qkv_list.append(index_first_axis(rearrange(qkv_tokens, "b s ... -> (b s) ..."), indices))
i_sum += length
token_lengths = [x_.shape[0] for x_ in qkv_list]
qkv = torch.cat(qkv_list, dim=0)
query, key, value = qkv.unbind(1)
cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
max_seqlen_q = cu_seqlens.max().item()
max_seqlen_k = max_seqlen_q
cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_k = cu_seqlens_q.clone()
output = flash_attn_varlen_func(
query,
key,
value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
causal=False,
softmax_scale=scale,
)
# To merge the tokens
        i_sum = 0
        token_sum = 0
for i_p, length in enumerate(hidden_length):
tot_token_num = token_lengths[i_p]
stage_output = output[token_sum : token_sum + tot_token_num]
stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, length * sp_group_size)
stage_hidden_output = all_to_all(stage_output, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
token_sum += tot_token_num
i_sum += length
output_hidden = output_hidden.flatten(2, 3)
return output_hidden
class VarlenFlashSelfAttnSingle:
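    # Single-rank variant of the above: varlen FlashAttention over a single packed hidden sequence.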
def __init__(self):
pass
def __call__(
self, query, key, value, heads, scale,
hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
):
        assert encoder_attention_mask is not None, "The encoder attention mask needs to be set"
batch_size = query.shape[0]
output_hidden = torch.zeros_like(query)
qkv_list = []
num_stages = len(hidden_length)
qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
i_sum = 0
for i_p, length in enumerate(hidden_length):
qkv_tokens = qkv[:, i_sum:i_sum+length]
if image_rotary_emb is not None:
qkv_tokens[:,:,0], qkv_tokens[:,:,1] = apply_rope(qkv_tokens[:,:,0], qkv_tokens[:,:,1], image_rotary_emb[i_p])
indices = encoder_attention_mask[i_p]['indices']
qkv_list.append(index_first_axis(rearrange(qkv_tokens, "b s ... -> (b s) ..."), indices))
i_sum += length
token_lengths = [x_.shape[0] for x_ in qkv_list]
qkv = torch.cat(qkv_list, dim=0)
query, key, value = qkv.unbind(1)
cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
max_seqlen_q = cu_seqlens.max().item()
max_seqlen_k = max_seqlen_q
cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_k = cu_seqlens_q.clone()
output = flash_attn_varlen_func(
query,
key,
value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
causal=False,
softmax_scale=scale,
)
# To merge the tokens
        i_sum = 0
        token_sum = 0
for i_p, length in enumerate(hidden_length):
tot_token_num = token_lengths[i_p]
stage_output = output[token_sum : token_sum + tot_token_num]
stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, length)
output_hidden[:, i_sum:i_sum+length] = stage_output
token_sum += tot_token_num
i_sum += length
output_hidden = output_hidden.flatten(2, 3)
return output_hidden
class SequenceParallelVarlenAttnSingle:
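    # Sequence-parallel attention over a single packed hidden sequence using scaled_dot_product_attention
    # with per-stage masks; heads are exchanged across the sequence parallel group with all_to_all.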
def __init__(self):
pass
def __call__(
self, query, key, value, heads, scale,
hidden_length=None, image_rotary_emb=None, attention_mask=None,
):
        assert attention_mask is not None, "The attention mask needs to be set"
num_stages = len(hidden_length)
qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
        # To sync the query, key and values across the sequence parallel group
sp_group = get_sequence_parallel_group()
sp_group_size = get_sequence_parallel_world_size()
i_sum = 0
output_hidden_list = []
for i_p, length in enumerate(hidden_length):
qkv_tokens = qkv[:, i_sum:i_sum+length]
qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
if image_rotary_emb is not None:
qkv_tokens[:,:,0], qkv_tokens[:,:,1] = apply_rope(qkv_tokens[:,:,0], qkv_tokens[:,:,1], image_rotary_emb[i_p])
query, key, value = qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
stage_hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
)
stage_hidden_states = stage_hidden_states.transpose(1, 2) # [bs, tot_seq, nhead, dim]
output_hidden = stage_hidden_states
output_hidden = all_to_all(output_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
output_hidden_list.append(output_hidden)
i_sum += length
output_hidden = torch.cat(output_hidden_list, dim=1).flatten(2, 3)
return output_hidden
class VarlenSelfAttnSingle:
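    # Single-rank attention over a single packed hidden sequence using scaled_dot_product_attention
    # with per-stage masks.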
def __init__(self):
pass
def __call__(
self, query, key, value, heads, scale,
hidden_length=None, image_rotary_emb=None, attention_mask=None,
):
        assert attention_mask is not None, "The attention mask needs to be set"
num_stages = len(hidden_length)
qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
i_sum = 0
output_hidden_list = []
for i_p, length in enumerate(hidden_length):
qkv_tokens = qkv[:, i_sum:i_sum+length]
if image_rotary_emb is not None:
qkv_tokens[:,:,0], qkv_tokens[:,:,1] = apply_rope(qkv_tokens[:,:,0], qkv_tokens[:,:,1], image_rotary_emb[i_p])
query, key, value = qkv_tokens.unbind(2)
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
stage_hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
)
stage_hidden_states = stage_hidden_states.transpose(1, 2).flatten(2, 3) # [bs, tot_seq, dim]
output_hidden_list.append(stage_hidden_states)
i_sum += length
output_hidden = torch.cat(output_hidden_list, dim=1)
return output_hidden
class Attention(nn.Module):
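    # Multi-head attention module in the diffusers style: builds the q/k/v projections (plus optional
    # added context q/k/v projections and optional RMSNorm on q/k) and delegates the actual attention
    # computation to the configured processor.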
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
context_pre_only=None,
pre_only=False,
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim
self.query_dim = query_dim
self.use_bias = bias
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.scale = dim_head**-0.5
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention
if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
)
if qk_norm is None:
self.norm_q = None
self.norm_k = None
elif qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
else:
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
else:
self.to_k = None
self.to_v = None
self.added_proj_bias = added_proj_bias
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
if qk_norm is not None and added_kv_proj_dim is not None:
if qk_norm == "fp32_layer_norm":
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
elif qk_norm == "rms_norm":
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
else:
self.norm_added_q = None
self.norm_added_k = None
# set attention processor
self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor") -> None:
self.processor = processor
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
hidden_length: List = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
attention_mask=attention_mask,
hidden_length=hidden_length,
image_rotary_emb=image_rotary_emb,
)
class FluxSingleAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self, use_flash_attn=False):
self.use_flash_attn = use_flash_attn
if self.use_flash_attn:
if is_sequence_parallel_initialized():
self.varlen_flash_attn = SequenceParallelVarlenFlashAttnSingle()
else:
self.varlen_flash_attn = VarlenFlashSelfAttnSingle()
else:
if is_sequence_parallel_initialized():
self.varlen_attn = SequenceParallelVarlenAttnSingle()
else:
self.varlen_attn = VarlenSelfAttnSingle()
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_length: List = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(query.shape[0], -1, attn.heads, head_dim)
key = key.view(key.shape[0], -1, attn.heads, head_dim)
value = value.view(value.shape[0], -1, attn.heads, head_dim)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
if self.use_flash_attn:
hidden_states = self.varlen_flash_attn(
query, key, value,
attn.heads, attn.scale, hidden_length,
image_rotary_emb, encoder_attention_mask,
)
else:
hidden_states = self.varlen_attn(
query, key, value,
attn.heads, attn.scale, hidden_length,
image_rotary_emb, attention_mask,
)
return hidden_states
class FluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self, use_flash_attn=False):
self.use_flash_attn = use_flash_attn
if self.use_flash_attn:
if is_sequence_parallel_initialized():
self.varlen_flash_attn = SequenceParallelVarlenFlashSelfAttentionWithT5Mask()
else:
self.varlen_flash_attn = VarlenFlashSelfAttentionWithT5Mask()
else:
if is_sequence_parallel_initialized():
self.varlen_attn = SequenceParallelVarlenSelfAttentionWithT5Mask()
else:
self.varlen_attn = VarlenSelfAttentionWithT5Mask()
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_length: List = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(query.shape[0], -1, attn.heads, head_dim)
key = key.view(key.shape[0], -1, attn.heads, head_dim)
value = value.view(value.shape[0], -1, attn.heads, head_dim)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
encoder_hidden_states_query_proj.shape[0], -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
encoder_hidden_states_key_proj.shape[0], -1, attn.heads, head_dim
)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
encoder_hidden_states_value_proj.shape[0], -1, attn.heads, head_dim
)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
if self.use_flash_attn:
hidden_states, encoder_hidden_states = self.varlen_flash_attn(
query, key, value,
encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj, attn.heads, attn.scale, hidden_length,
image_rotary_emb, encoder_attention_mask,
)
else:
hidden_states, encoder_hidden_states = self.varlen_attn(
query, key, value,
encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj, attn.heads, attn.scale, hidden_length,
image_rotary_emb, attention_mask,
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class FluxSingleTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0, use_flash_attn=False):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
processor = FluxSingleAttnProcessor2_0(use_flash_attn)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
encoder_attention_mask=None,
attention_mask=None,
hidden_length=None,
image_rotary_emb=None,
):
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb, hidden_length=hidden_length)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
encoder_attention_mask=encoder_attention_mask,
attention_mask=attention_mask,
hidden_length=hidden_length,
image_rotary_emb=image_rotary_emb,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
class FluxTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""
def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6, use_flash_attn=False):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
if hasattr(F, "scaled_dot_product_attention"):
processor = FluxAttnProcessor2_0(use_flash_attn)
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
encoder_attention_mask: torch.FloatTensor,
temb: torch.FloatTensor,
attention_mask: torch.FloatTensor = None,
hidden_length: List = None,
image_rotary_emb=None,
):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb, hidden_length=hidden_length)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
attention_mask=attention_mask,
hidden_length=hidden_length,
image_rotary_emb=image_rotary_emb,
)
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp * ff_output
hidden_states = hidden_states + ff_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
import numbers
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers.utils import is_torch_version
if is_torch_version(">=", "2.1.0"):
LayerNorm = nn.LayerNorm
else:
# Has optional bias parameter compared to torch layer norm
# TODO: replace with torch layernorm once min required torch version >= 2.1
class LayerNorm(nn.Module):
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
else:
self.weight = None
self.bias = None
def forward(self, input):
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
class FP32LayerNorm(nn.LayerNorm):
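    # LayerNorm computed in float32 and cast back to the input dtype, for stability in half precision.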
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
return F.layer_norm(
inputs.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
).to(origin_dtype)
class RMSNorm(nn.Module):
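    # Root-mean-square normalization (no mean subtraction), with an optional learned scale.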
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states * self.weight
else:
hidden_states = hidden_states.to(input_dtype)
return hidden_states
class AdaLayerNormContinuous(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
def forward_with_pad(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
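        # hidden_length holds the packed per-stage token counts; each stage i_p reuses the
        # conditioning embedding of its own sub-batch (emb[i_p::num_stages]) across its tokens.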
assert hidden_length is not None
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
batch_emb = torch.zeros_like(x).repeat(1, 1, 2)
i_sum = 0
num_stages = len(hidden_length)
for i_p, length in enumerate(hidden_length):
batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
i_sum += length
batch_scale, batch_shift = torch.chunk(batch_emb, 2, dim=2)
x = self.norm(x) * (1 + batch_scale) + batch_shift
return x
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
if hidden_length is not None:
return self.forward_with_pad(x, conditioning_embedding, hidden_length)
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class AdaLayerNormZero(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
super().__init__()
self.emb = None
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
def forward_with_pad(
self,
x: torch.Tensor,
timestep: Optional[torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
hidden_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# hidden_length: [[20, 30], [30, 40], [50, 60]]
# x: [bs, seq_len, dim]
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
emb = self.linear(self.silu(emb))
batch_emb = torch.zeros_like(x).repeat(1, 1, 6)
i_sum = 0
num_stages = len(hidden_length)
for i_p, length in enumerate(hidden_length):
batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
i_sum += length
batch_shift_msa, batch_scale_msa, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp = batch_emb.chunk(6, dim=2)
x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa
return x, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp
def forward(
self,
x: torch.Tensor,
timestep: Optional[torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
hidden_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if hidden_length is not None:
return self.forward_with_pad(x, timestep, class_labels, hidden_dtype, emb, hidden_length)
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZeroSingle(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward_with_pad(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
hidden_length: Optional[torch.Tensor] = None,
):
emb = self.linear(self.silu(emb))
batch_emb = torch.zeros_like(x).repeat(1, 1, 3)
i_sum = 0
num_stages = len(hidden_length)
for i_p, length in enumerate(hidden_length):
batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
i_sum += length
batch_shift_msa, batch_scale_msa, batch_gate_msa = batch_emb.chunk(3, dim=2)
x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa
return x, batch_gate_msa
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
hidden_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if hidden_length is not None:
return self.forward_with_pad(x, emb, hidden_length)
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa