Commit 0e56f303 authored by mashun's avatar mashun
Browse files

pyramid-flow

parents
Pipeline #2007 canceled with stages
from typing import Any, Dict, List, Optional, Union
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from tqdm import tqdm
from diffusers.utils.torch_utils import randn_tensor
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import is_torch_version
from .modeling_normalization import AdaLayerNormContinuous
from .modeling_embedding import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
from .modeling_flux_block import FluxTransformerBlock, FluxSingleTransformerBlock
from trainer_misc import (
is_sequence_parallel_initialized,
get_sequence_parallel_group,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
all_to_all,
)
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Build rotary position-embedding rotation matrices.

    Args:
        pos: positions of shape [batch, seq_len].
        dim: channel count for this axis; must be even.
        theta: RoPE frequency base.

    Returns:
        float32 tensor of shape [batch, seq_len, dim // 2, 2, 2] holding the
        per-frequency 2x2 rotation matrices [[cos, -sin], [sin, cos]].
    """
    assert dim % 2 == 0, "The dimension must be even."

    bsz = pos.shape[0]
    # Geometric frequency ladder over the even channels (float64 for precision).
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    freqs = 1.0 / (theta ** exponents)

    # Outer product: one rotation angle per (position, frequency) pair.
    angles = torch.einsum("...n,d->...nd", pos, freqs)
    cos_a, sin_a = torch.cos(angles), torch.sin(angles)

    # Pack [[cos, -sin], [sin, cos]] and expose the trailing 2x2 matrix axes.
    rot = torch.stack([cos_a, -sin_a, sin_a, cos_a], dim=-1)
    return rot.view(bsz, -1, dim // 2, 2, 2).float()
class EmbedND(nn.Module):
    """Multi-axis rotary embedding.

    Applies `rope` independently to each coordinate axis of the input ids
    (e.g. time / height / width) and concatenates the resulting rotation
    matrices along the frequency axis.
    """

    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # One rope table per axis; ids[..., k] holds axis-k positions.
        per_axis = [
            rope(ids[..., axis], self.axes_dim[axis], self.theta)
            for axis in range(ids.shape[-1])
        ]
        # Concatenate over the frequency dimension, then add a broadcast
        # head axis so the result applies to every attention head.
        return torch.cat(per_axis, dim=-3).unsqueeze(2)
class PyramidFluxTransformer(ModelMixin, ConfigMixin):
    """
    The Transformer model introduced in Flux, adapted here for pyramidal
    (multi-resolution) video generation.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Parameters:
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        # NOTE(review): mutable default argument; never mutated here, but a
        # tuple would be safer.
        axes_dims_rope: List[int] = [16, 24, 24],
        use_flash_attn: bool = False,
        use_temporal_causal: bool = True,
        interp_condition_pos: bool = True,
        use_gradient_checkpointing: bool = False,
        gradient_checkpointing_ratio: float = 0.6,
    ):
        super().__init__()
        self.out_channels = in_channels
        # Hidden width of the transformer trunk.
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        # 3-axis (time, height, width) rotary position embedding.
        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
        # Fuses the timestep embedding with the pooled text projection.
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
        )

        # Project T5 token embeddings and patchified latents into the trunk width.
        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)

        # Dual-stream (text + video) MMDiT blocks.
        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    use_flash_attn=use_flash_attn,
                )
                for i in range(self.config.num_layers)
            ]
        )

        # Single-stream DiT blocks operating on the concatenated text+video sequence.
        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    use_flash_attn=use_flash_attn,
                )
                for i in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = use_gradient_checkpointing
        self.gradient_checkpointing_ratio = gradient_checkpointing_ratio

        self.use_temporal_causal = use_temporal_causal
        if self.use_temporal_causal:
            print("Using temporal causal attention")

        self.use_flash_attn = use_flash_attn
        if self.use_flash_attn:
            print("Using Flash attention")

        # NOTE(review): this overrides the `patch_size` config argument above
        # (default 1) for all spatial patchify/unpatchify logic — confirm intended.
        self.patch_size = 2 # hard-code for now

        # init weights
        self.initialize_weights()
def initialize_weights(self):
    """Apply DiT-style initialization: Xavier for generic layers, small normal
    for conditioning projections, and zeros for adaLN modulations and outputs
    (so each block starts as an identity residual)."""
    # Initialize transformer layers:
    def _basic_init(module):
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
    self.apply(_basic_init)

    # Initialize all the conditioning to normal init
    nn.init.normal_(self.time_text_embed.timestep_embedder.linear_1.weight, std=0.02)
    nn.init.normal_(self.time_text_embed.timestep_embedder.linear_2.weight, std=0.02)
    nn.init.normal_(self.time_text_embed.text_embedder.linear_1.weight, std=0.02)
    nn.init.normal_(self.time_text_embed.text_embedder.linear_2.weight, std=0.02)
    nn.init.normal_(self.context_embedder.weight, std=0.02)

    # Zero-out adaLN modulation layers in DiT blocks:
    for block in self.transformer_blocks:
        nn.init.constant_(block.norm1.linear.weight, 0)
        nn.init.constant_(block.norm1.linear.bias, 0)
        nn.init.constant_(block.norm1_context.linear.weight, 0)
        nn.init.constant_(block.norm1_context.linear.bias, 0)
    for block in self.single_transformer_blocks:
        nn.init.constant_(block.norm.linear.weight, 0)
        nn.init.constant_(block.norm.linear.bias, 0)

    # Zero-out output layers so the model initially predicts zeros:
    nn.init.constant_(self.norm_out.linear.weight, 0)
    nn.init.constant_(self.norm_out.linear.bias, 0)
    nn.init.constant_(self.proj_out.weight, 0)
    nn.init.constant_(self.proj_out.bias, 0)
@torch.no_grad()
def _prepare_image_ids(self, batch_size, temp, height, width, train_height, train_width, device, start_time_stamp=0):
    """Build (time, row, col) RoPE coordinate ids for one clip.

    Returns a tensor of shape [batch_size, temp * height * width, 3].
    When the clip's spatial size differs from the stage's reference
    (train) size, row/col positions are linearly interpolated from the
    train grid so all pyramid stages share one coordinate frame.
    """
    latent_image_ids = torch.zeros(temp, height, width, 3)

    # Temporal Rope: frame index, offset by the clip's start timestamp.
    latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(start_time_stamp, start_time_stamp + temp)[:, None, None]

    # height Rope: resample train-resolution row positions onto `height` rows if needed.
    if height != train_height:
        height_pos = F.interpolate(torch.arange(train_height)[None, None, :].float(), height, mode='linear').squeeze(0, 1)
    else:
        height_pos = torch.arange(train_height).float()
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + height_pos[None, :, None]

    # width rope: same resampling for columns.
    if width != train_width:
        width_pos = F.interpolate(torch.arange(train_width)[None, None, :].float(), width, mode='linear').squeeze(0, 1)
    else:
        width_pos = torch.arange(train_width).float()
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + width_pos[None, None, :]

    # Tile over the batch and flatten (t, h, w) into a single token axis.
    latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
    latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')

    return latent_image_ids.to(device=device)
@torch.no_grad()
def _prepare_pyramid_image_ids(self, sample, batch_size, device):
    """Build RoPE coordinate ids for every pyramid stage.

    `sample` is a list over stages; each stage is a latent tensor or a list
    of clip latents shaped [b, c, t, h, w]. Clips within a stage are laid
    out back-to-back along time, and spatial coordinates are normalized to
    the stage's last (reference-resolution) clip. Returns one
    [batch, tokens, 3] id tensor per stage.
    """
    image_ids_list = []

    for i_b, sample_ in enumerate(sample):
        if not isinstance(sample_, list):
            sample_ = [sample_]

        cur_image_ids = []
        start_time_stamp = 0

        # The last clip defines the stage's reference (train) spatial grid.
        train_height = sample_[-1].shape[-2] // self.patch_size
        train_width = sample_[-1].shape[-1] // self.patch_size

        for clip_ in sample_:
            _, _, temp, height, width = clip_.shape
            height = height // self.patch_size
            width = width // self.patch_size
            cur_image_ids.append(self._prepare_image_ids(batch_size, temp, height, width, train_height, train_width, device, start_time_stamp=start_time_stamp))
            # Clips are consecutive in time.
            start_time_stamp += temp

        cur_image_ids = torch.cat(cur_image_ids, dim=1)
        image_ids_list.append(cur_image_ids)

    return image_ids_list
def merge_input(self, sample, encoder_hidden_length, encoder_attention_mask):
    """
    Merge the input video with different resolutions into one sequence.
    Sample: From low resolution to high resolution.

    Returns per-stage patchified hidden states and token counts, the stage
    temporal/spatial sizes, the trainable-token counts, the (possibly
    repacked) encoder attention mask, the per-stage attention masks, and
    the RoPE rotation tables.
    """
    if isinstance(sample[0], list):
        device = sample[0][-1].device
        pad_batch_size = sample[0][-1].shape[0]
    else:
        device = sample[0].device
        pad_batch_size = sample[0].shape[0]

    num_stages = len(sample)
    height_list = [];width_list = [];temp_list = []
    trainable_token_list = []

    # Record each stage's target (last-clip) token geometry.
    for i_b, sample_ in enumerate(sample):
        if isinstance(sample_, list):
            sample_ = sample_[-1]
        _, _, temp, height, width = sample_.shape
        height = height // self.patch_size
        width = width // self.patch_size
        temp_list.append(temp)
        height_list.append(height)
        width_list.append(width)
        # Only the last clip's tokens take part in the loss (earlier clips condition).
        trainable_token_list.append(height * width * temp)

    # prepare the RoPE IDs: text tokens get all-zero coordinates,
    # video tokens get (t, h, w) coordinates.
    image_ids_list = self._prepare_pyramid_image_ids(sample, pad_batch_size, device)
    text_ids = torch.zeros(pad_batch_size, encoder_attention_mask.shape[1], 3).to(device=device)
    input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
    image_rotary_emb = [self.pos_embed(input_ids) for input_ids in input_ids_list] # [bs, seq_len, 1, head_dim // 2, 2, 2]

    if is_sequence_parallel_initialized():
        # Shard rotary tables and coordinate ids along the sequence axis.
        sp_group = get_sequence_parallel_group()
        sp_group_size = get_sequence_parallel_world_size()
        concat_output = True if self.training else False
        image_rotary_emb = [all_to_all(x_.repeat(1, 1, sp_group_size, 1, 1, 1), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output) for x_ in image_rotary_emb]
        input_ids_list = [all_to_all(input_ids.repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output) for input_ids in input_ids_list]

    # Patchify each clip and embed into the trunk width.
    # NOTE(review): this loop iterates `sample_` directly, so it assumes every
    # stage is a *list* of clips; a bare tensor stage would iterate over its
    # batch dim — confirm callers always pass lists here.
    hidden_states, hidden_length = [], []
    for sample_ in sample:
        video_tokens = []
        for each_latent in sample_:
            each_latent = rearrange(each_latent, 'b c t h w -> b t h w c')
            each_latent = rearrange(each_latent, 'b t (h p1) (w p2) c -> b (t h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)
            video_tokens.append(each_latent)
        video_tokens = torch.cat(video_tokens, dim=1)
        video_tokens = self.x_embedder(video_tokens)
        hidden_states.append(video_tokens)
        hidden_length.append(video_tokens.shape[1])

    # prepare the attention mask
    if self.use_flash_attn:
        attention_mask = None
        indices_list = []
        for i_p, length in enumerate(hidden_length):
            # Stage i_p owns batch rows i_p, i_p + num_stages, ...; text mask
            # comes from the encoder, video tokens are always valid.
            pad_attention_mask = torch.ones((pad_batch_size, length), dtype=encoder_attention_mask.dtype).to(device)
            pad_attention_mask = torch.cat([encoder_attention_mask[i_p::num_stages], pad_attention_mask], dim=1)

            if is_sequence_parallel_initialized():
                sp_group = get_sequence_parallel_group()
                sp_group_size = get_sequence_parallel_world_size()
                pad_attention_mask = all_to_all(pad_attention_mask.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0)
                pad_attention_mask = pad_attention_mask.squeeze(2)

            # Varlen-attention bookkeeping: flat indices of valid tokens and
            # per-row sequence lengths.
            seqlens_in_batch = pad_attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(pad_attention_mask.flatten(), as_tuple=False).flatten()

            indices_list.append(
                {
                    'indices': indices,
                    'seqlens_in_batch': seqlens_in_batch,
                }
            )
        encoder_attention_mask = indices_list
    else:
        assert encoder_attention_mask.shape[1] == encoder_hidden_length
        real_batch_size = encoder_attention_mask.shape[0]

        # prepare text ids: each sample gets a distinct id in 1..B; padded
        # text positions are zeroed so they match nothing.
        text_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, encoder_hidden_length)
        text_ids = text_ids.to(device)
        text_ids[encoder_attention_mask == 0] = 0

        # prepare image ids: same per-sample id, giving a block-diagonal
        # (per-sample) attention pattern once compared below.
        image_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, max(hidden_length))
        image_ids = image_ids.to(device)
        image_ids_list = []
        for i_p, length in enumerate(hidden_length):
            image_ids_list.append(image_ids[i_p::num_stages][:, :length])

        if is_sequence_parallel_initialized():
            sp_group = get_sequence_parallel_group()
            sp_group_size = get_sequence_parallel_world_size()
            concat_output = True if self.training else False
            text_ids = all_to_all(text_ids.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output).squeeze(2)
            image_ids_list = [all_to_all(image_ids_.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output).squeeze(2) for image_ids_ in image_ids_list]

        attention_mask = []
        for i_p in range(len(hidden_length)):
            image_ids = image_ids_list[i_p]
            token_ids = torch.cat([text_ids[i_p::num_stages], image_ids], dim=1)
            # Tokens may attend only within the same sample id.
            stage_attention_mask = rearrange(token_ids, 'b i -> b 1 i 1') == rearrange(token_ids, 'b j -> b 1 1 j') # [bs, 1, q_len, k_len]
            if self.use_temporal_causal:
                # Additionally restrict attention to same-or-earlier frames
                # (coordinate channel 0 is the frame index).
                input_order_ids = input_ids_list[i_p][:,:,0]
                temporal_causal_mask = rearrange(input_order_ids, 'b i -> b 1 i 1') >= rearrange(input_order_ids, 'b j -> b 1 1 j')
                stage_attention_mask = stage_attention_mask & temporal_causal_mask
            attention_mask.append(stage_attention_mask)

    return hidden_states, hidden_length, temp_list, height_list, width_list, trainable_token_list, encoder_attention_mask, attention_mask, image_rotary_emb
def split_output(self, batch_hidden_states, hidden_length, temps, heights, widths, trainable_token_list):
    """Split the merged token sequence back into per-stage video latents.

    Keeps only each stage's trainable (last-clip) tokens and unpatchifies
    them to [b, c, t, h, w]. Returns one tensor per stage.
    """
    # To split the hidden states
    batch_size = batch_hidden_states.shape[0]
    output_hidden_list = []
    batch_hidden_states = torch.split(batch_hidden_states, hidden_length, dim=1)

    if is_sequence_parallel_initialized():
        sp_group_size = get_sequence_parallel_world_size()
        if self.training:
            # During training the batch axis was gathered across the SP group.
            batch_size = batch_size // sp_group_size

    for i_p, length in enumerate(hidden_length):
        width, height, temp = widths[i_p], heights[i_p], temps[i_p]
        trainable_token_num = trainable_token_list[i_p]
        hidden_states = batch_hidden_states[i_p]

        if is_sequence_parallel_initialized():
            # Re-gather the full sequence for this stage.
            sp_group = get_sequence_parallel_group()
            sp_group_size = get_sequence_parallel_world_size()
            if not self.training:
                hidden_states = hidden_states.repeat(sp_group_size, 1, 1)
            hidden_states = all_to_all(hidden_states, sp_group, sp_group_size, scatter_dim=0, gather_dim=1)

        # only the trainable token are taking part in loss computation
        hidden_states = hidden_states[:, -trainable_token_num:]

        # unpatchify
        # NOTE(review): `out_channels // 4` pairs with the hard-coded
        # patch_size == 2 (p1 * p2 * c == proj_out width) — confirm if the
        # patch size ever changes.
        hidden_states = hidden_states.reshape(
            shape=(batch_size, temp, height, width, self.patch_size, self.patch_size, self.out_channels // 4)
        )
        hidden_states = rearrange(hidden_states, "b t h w p1 p2 c -> b t (h p1) (w p2) c")
        hidden_states = rearrange(hidden_states, "b t h w c -> b c t h w")

        output_hidden_list.append(hidden_states)

    return output_hidden_list
def forward(
    self,
    sample: torch.FloatTensor, # [num_stages]
    encoder_hidden_states: torch.Tensor = None,
    encoder_attention_mask: torch.FloatTensor = None,
    pooled_projections: torch.Tensor = None,
    timestep_ratio: torch.LongTensor = None,
):
    """Run the pyramidal Flux transformer.

    Args:
        sample: list over pyramid stages; each stage is a list of latent
            clips [b, c, t, h, w] (low to high resolution).
        encoder_hidden_states: T5 token embeddings.
        encoder_attention_mask: text token validity mask.
        pooled_projections: pooled CLIP embedding.
        timestep_ratio: diffusion timestep conditioning.

    Returns:
        A list of per-stage denoised latents (from `split_output`).
    """
    # Timestep + pooled-text conditioning vector, and text token projection.
    temb = self.time_text_embed(timestep_ratio, pooled_projections)
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
    encoder_hidden_length = encoder_hidden_states.shape[1]

    # Get the input sequence
    hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, attention_mask, \
        image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask)

    # split the long latents if necessary
    if is_sequence_parallel_initialized():
        sp_group = get_sequence_parallel_group()
        sp_group_size = get_sequence_parallel_world_size()
        concat_output = True if self.training else False

        # sync the input hidden states
        batch_hidden_states = []
        for i_p, hidden_states_ in enumerate(hidden_states):
            assert hidden_states_.shape[1] % sp_group_size == 0, "The sequence length should be divided by sequence parallel size"
            hidden_states_ = all_to_all(hidden_states_, sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
            hidden_length[i_p] = hidden_length[i_p] // sp_group_size
            batch_hidden_states.append(hidden_states_)

        # sync the encoder hidden states
        hidden_states = torch.cat(batch_hidden_states, dim=1)
        encoder_hidden_states = all_to_all(encoder_hidden_states, sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
        temb = all_to_all(temb.unsqueeze(1).repeat(1, sp_group_size, 1), sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
        temb = temb.squeeze(1)
    else:
        hidden_states = torch.cat(hidden_states, dim=1)

    # Dual-stream MMDiT blocks; optionally checkpoint the first
    # `gradient_checkpointing_ratio` fraction of them.
    for index_block, block in enumerate(self.transformer_blocks):
        if self.training and self.gradient_checkpointing and (index_block <= int(len(self.transformer_blocks) * self.gradient_checkpointing_ratio)):
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)
                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                encoder_hidden_states,
                encoder_attention_mask,
                temb,
                attention_mask,
                hidden_length,
                image_rotary_emb,
                **ckpt_kwargs,
            )
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                temb=temb,
                attention_mask=attention_mask,
                hidden_length=hidden_length,
                image_rotary_emb=image_rotary_emb,
            )

    # remerge for single attention block: prepend each stage's text tokens
    # to its video tokens so the single-stream blocks see one sequence.
    num_stages = len(hidden_length)
    batch_hidden_states = list(torch.split(hidden_states, hidden_length, dim=1))
    concat_hidden_length = []

    if is_sequence_parallel_initialized():
        sp_group = get_sequence_parallel_group()
        sp_group_size = get_sequence_parallel_world_size()
        encoder_hidden_states = all_to_all(encoder_hidden_states, sp_group, sp_group_size, scatter_dim=0, gather_dim=1)

    for i_p in range(len(hidden_length)):
        if is_sequence_parallel_initialized():
            # Gather the stage's full sequence before concatenating text.
            sp_group = get_sequence_parallel_group()
            sp_group_size = get_sequence_parallel_world_size()
            batch_hidden_states[i_p] = all_to_all(batch_hidden_states[i_p], sp_group, sp_group_size, scatter_dim=0, gather_dim=1)

        batch_hidden_states[i_p] = torch.cat([encoder_hidden_states[i_p::num_stages], batch_hidden_states[i_p]], dim=1)

        if is_sequence_parallel_initialized():
            # Re-shard the concatenated sequence.
            sp_group = get_sequence_parallel_group()
            sp_group_size = get_sequence_parallel_world_size()
            batch_hidden_states[i_p] = all_to_all(batch_hidden_states[i_p], sp_group, sp_group_size, scatter_dim=1, gather_dim=0)

        concat_hidden_length.append(batch_hidden_states[i_p].shape[1])

    hidden_states = torch.cat(batch_hidden_states, dim=1)

    # Single-stream DiT blocks over the text+video sequence.
    for index_block, block in enumerate(self.single_transformer_blocks):
        if self.training and self.gradient_checkpointing and (index_block <= int(len(self.single_transformer_blocks) * self.gradient_checkpointing_ratio)):
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)
                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                temb,
                encoder_attention_mask,
                attention_mask,
                concat_hidden_length,
                image_rotary_emb,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                encoder_attention_mask=encoder_attention_mask, # varlen indices when flash attention is used
                attention_mask=attention_mask,
                hidden_length=concat_hidden_length,
                image_rotary_emb=image_rotary_emb,
            )

    # Strip the text tokens back off each stage.
    batch_hidden_states = list(torch.split(hidden_states, concat_hidden_length, dim=1))

    for i_p in range(len(concat_hidden_length)):
        if is_sequence_parallel_initialized():
            sp_group = get_sequence_parallel_group()
            sp_group_size = get_sequence_parallel_world_size()
            batch_hidden_states[i_p] = all_to_all(batch_hidden_states[i_p], sp_group, sp_group_size, scatter_dim=0, gather_dim=1)

        batch_hidden_states[i_p] = batch_hidden_states[i_p][:, encoder_hidden_length :, ...]

        if is_sequence_parallel_initialized():
            sp_group = get_sequence_parallel_group()
            sp_group_size = get_sequence_parallel_world_size()
            batch_hidden_states[i_p] = all_to_all(batch_hidden_states[i_p], sp_group, sp_group_size, scatter_dim=1, gather_dim=0)

    hidden_states = torch.cat(batch_hidden_states, dim=1)

    # Final adaLN + projection back to patch pixels, then unpatchify per stage.
    hidden_states = self.norm_out(hidden_states, temb, hidden_length=hidden_length)
    hidden_states = self.proj_out(hidden_states)

    output = self.split_output(hidden_states, hidden_length, temps, heights, widths, trainable_token_list)

    return output
\ No newline at end of file
import torch
import torch.nn as nn
import os
from transformers import (
CLIPTextModel,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
from typing import Any, Callable, Dict, List, Optional, Union
class FluxTextEncoderWithMask(nn.Module):
    """Frozen CLIP + T5 text-encoder pair used to condition the Flux model.

    CLIP supplies the pooled prompt embedding; T5 supplies per-token
    embeddings plus an attention mask. All parameters are frozen.
    """

    def __init__(self, model_path, torch_dtype):
        super().__init__()
        # CLIP text encoder (pooled-embedding branch)
        self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'), torch_dtype=torch_dtype)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.text_encoder = CLIPTextModel.from_pretrained(os.path.join(model_path, 'text_encoder'), torch_dtype=torch_dtype)

        # T5 (per-token embedding branch)
        self.tokenizer_2 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
        self.text_encoder_2 = T5EncoderModel.from_pretrained(os.path.join(model_path, 'text_encoder_2'), torch_dtype=torch_dtype)

        self._freeze()

    def _freeze(self):
        # Text encoders are inference-only; never trained here.
        for param in self.parameters():
            param.requires_grad = False
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
):
    """Tokenize and encode `prompt` with the T5 encoder.

    Args:
        prompt: a single prompt or list of prompts.
        num_images_per_prompt: how many generations per prompt; rows are
            duplicated accordingly.
        max_sequence_length: T5 padding/truncation length.
        device: device the returned tensors are placed on.

    Returns:
        prompt_embeds: [batch * num_images_per_prompt, seq_len, dim]
        prompt_attention_mask: [batch * num_images_per_prompt, seq_len]
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = self.tokenizer_2(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_attention_mask = text_inputs.attention_mask
    prompt_attention_mask = prompt_attention_mask.to(device)

    prompt_embeds = self.text_encoder_2(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False)[0]
    dtype = self.text_encoder_2.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # Duplicate text embeddings AND attention mask for each generation per
    # prompt with the same interleaved layout (mps friendly method).
    # FIX: the previous code tiled the mask with `.repeat(num_images_per_prompt, 1)`,
    # which ordered rows [p0, p1, p0, p1] while the embeddings were
    # interleaved [p0, p0, p1, p1] — misaligning mask rows with their
    # prompts whenever num_images_per_prompt > 1. Mirrors _get_clip_prompt_embeds.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
    prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds, prompt_attention_mask
def _get_clip_prompt_embeds(
    self,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
):
    """Encode `prompt` with CLIP and return the pooled embedding.

    Returns:
        prompt_embeds: [batch * num_images_per_prompt, dim] pooled output.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer_max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

    # Use pooled output of CLIPTextModel
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds
def encode_prompt(self,
    prompt,
    num_images_per_prompt=1,
    device=None,
):
    """Encode prompts with both branches.

    Returns:
        (T5 token embeddings, T5 attention mask, pooled CLIP embedding).
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    # Pooled embedding from CLIP.
    pooled_prompt_embeds = self._get_clip_prompt_embeds(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
    )
    # Per-token embeddings + mask from T5.
    prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
    )

    return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
def forward(self, input_prompts, device):
    """Encode `input_prompts` (one generation per prompt) without gradients."""
    with torch.no_grad():
        prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, 1, device=device)

    return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
\ No newline at end of file
from .modeling_text_encoder import SD3TextEncoderWithMask
from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT
from .modeling_mmdit_block import JointTransformerBlock
\ No newline at end of file
from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
import numpy as np
import math
from diffusers.models.activations import get_activation
from einops import rearrange
def get_1d_sincos_pos_embed(
    embed_dim, num_frames, cls_token=False, extra_tokens=0,
):
    """1-D sin/cos positional table for `num_frames` frame indices.

    Returns an array of shape (num_frames, embed_dim); when `cls_token` is
    set and `extra_tokens > 0`, that many all-zero rows are prepended.
    """
    frame_positions = np.arange(num_frames, dtype=np.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, frame_positions)  # (T, D)
    if cls_token and extra_tokens > 0:
        zero_rows = np.zeros([extra_tokens, embed_dim])
        table = np.concatenate([zero_rows, table], axis=0)
    return table
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    # Normalized row/col coordinates (scaled to `base_size`, then by the
    # interpolation scale).
    rows = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    cols = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale

    # here w goes first (upstream diffusers layout)
    mesh = np.stack(np.meshgrid(cols, rows), axis=0)
    mesh = mesh.reshape([2, 1, grid_size[1], grid_size[0]])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token and extra_tokens > 0:
        table = np.concatenate([np.zeros([extra_tokens, embed_dim]), table], axis=0)
    return table
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """2-D sin/cos embedding: half the channels encode grid[0], half grid[1]."""
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    # Geometric frequency ladder, identical to the original transformer.
    omega = 1.0 / (10000 ** (np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.0)))

    flat_pos = pos.reshape(-1)  # (M,)
    angles = flat_pos[:, None] * omega[None, :]  # (M, D/2) outer product

    # First half sine, second half cosine.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
    embeddings. :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # log-spaced frequencies exp(-log(max_period) * k / (half - shift))
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    # per-(timestep, frequency) phase angles, scaled
    angles = scale * (timesteps[:, None].float() * torch.exp(exponent)[None, :])

    # sine half followed by cosine half
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    # optionally swap the two halves so cosine comes first
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero-pad the final channel for odd embedding dims
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
class Timesteps(nn.Module):
    """Module wrapper around `get_timestep_embedding` with fixed settings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        # Delegate to the functional implementation with the stored config.
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
class TimestepEmbedding(nn.Module):
    """Two-layer MLP that projects a sinusoidal timestep embedding."""

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        sample_proj_bias=True,
    ):
        super().__init__()
        # NOTE(review): `out_dim` and `post_act_fn` are accepted but unused;
        # kept for signature compatibility — confirm against callers.
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        self.act = get_activation(act_fn)
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, sample_proj_bias)

    def forward(self, sample):
        # linear -> activation -> linear
        return self.linear_2(self.act(self.linear_1(sample)))
class TextProjection(nn.Module):
    """MLP projecting pooled text features into the model width."""

    def __init__(self, in_features, hidden_size, act_fn="silu"):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        self.act_1 = get_activation(act_fn)
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)

    def forward(self, caption):
        # linear -> activation -> linear
        return self.linear_2(self.act_1(self.linear_1(caption)))
class CombinedTimestepConditionEmbeddings(nn.Module):
    """Sum of the timestep embedding and the projected pooled text embedding."""

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = TextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        # Cast the sinusoidal projection to the text dtype so mixed-precision
        # pooled projections work. (N, D)
        sinusoid = self.time_proj(timestep).to(dtype=pooled_projection.dtype)
        return self.timestep_embedder(sinusoid) + self.text_embedder(pooled_projection)
class CombinedTimestepEmbeddings(nn.Module):
    """Timestep-only conditioning: sinusoidal projection followed by an MLP."""

    def __init__(self, embedding_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep):
        # (N, D)
        return self.timestep_embedder(self.time_proj(timestep))
class PatchEmbed3D(nn.Module):
    """Support the 3D Tensor input.

    Patchifies frames with a Conv2d and registers 2-D sin/cos positional
    tables (optionally plus a 1-D temporal table); RoPE variants skip the
    precomputed tables.
    """

    def __init__(
        self,
        height=128,
        width=128,
        patch_size=2,
        in_channels=16,
        embed_dim=1536,
        layer_norm=False,
        bias=True,
        interpolation_scale=1,
        pos_embed_type="sincos",
        temp_pos_embed_type='rope',
        pos_embed_max_size=192, # For SD3 cropping
        max_num_frames=64,
        add_temp_pos_embed=False,
        interp_condition_pos=False,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.layer_norm = layer_norm
        self.pos_embed_max_size = pos_embed_max_size

        # Per-frame patchify convolution.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # Token-grid geometry (in patches, not pixels).
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        self.add_temp_pos_embed = add_temp_pos_embed

        # Calculate positional embeddings based on max size or default
        if pos_embed_max_size:
            grid_size = pos_embed_max_size
        else:
            grid_size = int(num_patches**0.5)

        if pos_embed_type is None:
            self.pos_embed = None
        elif pos_embed_type == "sincos":
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
            )
            # Persist the table in state_dict only when sized by pos_embed_max_size.
            persistent = True if pos_embed_max_size else False
            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)

            if add_temp_pos_embed and temp_pos_embed_type == 'sincos':
                time_pos_embed = get_1d_sincos_pos_embed(embed_dim, max_num_frames)
                self.register_buffer("temp_pos_embed", torch.from_numpy(time_pos_embed).float().unsqueeze(0), persistent=True)
        elif pos_embed_type == "rope":
            print("Using the rotary position embedding")
        else:
            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

        self.pos_embed_type = pos_embed_type
        self.temp_pos_embed_type = temp_pos_embed_type
        self.interp_condition_pos = interp_condition_pos
def cropped_pos_embed(self, height, width, ori_height, ori_width):
    """Center-crop (and optionally bilinearly resize) the cached 2D sincos
    position table to the requested resolution, SD3-style.

    All sizes are given in pixels and converted to patch units here.
    Returns a tensor of shape [1, h*w, embed_dim].
    """
    if self.pos_embed_max_size is None:
        raise ValueError("`pos_embed_max_size` must be set for cropping.")
    p = self.patch_size
    h, w = height // p, width // p
    ori_h, ori_w = ori_height // p, ori_width // p
    assert ori_h >= h, "The ori_height needs >= height"
    assert ori_w >= w, "The ori_width needs >= width"
    grid_side = self.pos_embed_max_size
    for axis_name, size in (("Height", h), ("Width", w)):
        if size > grid_side:
            raise ValueError(
                f"{axis_name} ({size}) cannot be greater than `pos_embed_max_size`: {grid_side}."
            )
    # View the flat table as its 2D grid: [1, grid, grid, c].
    table = self.pos_embed.reshape(1, grid_side, grid_side, -1)
    if self.interp_condition_pos:
        # Crop at the reference (original) resolution, then resize down to the
        # target resolution so positions stay aligned across pyramid stages.
        top = (grid_side - ori_h) // 2
        left = (grid_side - ori_w) // 2
        crop = table[:, top:top + ori_h, left:left + ori_w, :]  # [b h w c]
        if (ori_h, ori_w) != (h, w):
            crop = F.interpolate(
                crop.permute(0, 3, 1, 2), size=(h, w), mode='bilinear'
            ).permute(0, 2, 3, 1)
    else:
        # Direct center crop at the target resolution.
        top = (grid_side - h) // 2
        left = (grid_side - w) // 2
        crop = table[:, top:top + h, left:left + w, :]
    return crop.reshape(1, -1, crop.shape[-1])
def forward_func(self, latent, time_index=0, ori_height=None, ori_width=None):
    """Patchify one [b, c, t, h, w] latent chunk and add position embeddings.

    Returns tokens shaped [b, t, n, c].  `time_index` offsets the temporal
    sincos table so consecutive chunks of a clip get consecutive frame
    positions; `ori_height`/`ori_width` are the reference resolution used by
    `cropped_pos_embed`.
    """
    if self.pos_embed_max_size is not None:
        # Cropping path keeps pixel units; division by patch_size happens
        # inside cropped_pos_embed.
        height, width = latent.shape[-2:]
    else:
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
    bs = latent.shape[0]
    temp = latent.shape[2]  # number of frames in this chunk
    # Fold time into the batch so the 2D conv patchifies every frame at once.
    latent = rearrange(latent, 'b c t h w -> (b t) c h w')
    latent = self.proj(latent)
    latent = latent.flatten(2).transpose(1, 2)  # (BT)CHW -> (BT)NC
    if self.layer_norm:
        latent = self.norm(latent)
    if self.pos_embed_type == 'sincos':
        # Spatial position embedding, Interpolate or crop positional embeddings as needed
        if self.pos_embed_max_size:
            pos_embed = self.cropped_pos_embed(height, width, ori_height, ori_width)
        else:
            raise NotImplementedError("Not implemented sincos pos embed without sd3 max pos crop")
            # NOTE(review): everything below this raise is unreachable dead code
            # (apparently left over from a non-cropping sincos path) — candidate
            # for removal; confirm before deleting.
            if self.height != height or self.width != width:
                pos_embed = get_2d_sincos_pos_embed(
                    embed_dim=self.pos_embed.shape[-1],
                    grid_size=(height, width),
                    base_size=self.base_size,
                    interpolation_scale=self.interpolation_scale,
                )
                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
            else:
                pos_embed = self.pos_embed
        if self.add_temp_pos_embed and self.temp_pos_embed_type == 'sincos':
            latent_dtype = latent.dtype
            latent = latent + pos_embed
            # Move time to the sequence axis so the per-frame table broadcasts.
            latent = rearrange(latent, '(b t) n c -> (b n) t c', t=temp)
            latent = latent + self.temp_pos_embed[:, time_index:time_index + temp, :]
            latent = latent.to(latent_dtype)
            latent = rearrange(latent, '(b n) t c -> b t n c', b=bs)
        else:
            latent = (latent + pos_embed).to(latent.dtype)
            latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)
    else:
        assert self.pos_embed_type == "rope", "Only supporting the sincos and rope embedding"
        # Rotary embedding is applied later (inside attention), not here.
        latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)
    return latent
def forward(self, latent):
    """Embed latent(s) into flattened token sequences.

    A plain [b, c, t, h, w] tensor yields one [b, t*n, c] sequence.  A list of
    (lists of) latents is treated as pyramid stages: each stage's chunks are
    embedded with a running time offset and a shared reference resolution
    (taken from the stage's last chunk), then concatenated along the sequence
    axis; a list of per-stage sequences is returned.
    """
    if not isinstance(latent, list):
        tokens = self.forward_func(latent)
        # b t n c -> b (t n) c
        return tokens.flatten(1, 2)
    outputs = []
    for stage_latents in latent:
        if not isinstance(stage_latents, list):
            stage_latents = [stage_latents]
        # Reference resolution for positional-embedding cropping.
        ref_h, ref_w = stage_latents[-1].shape[-2:]
        pieces = []
        t_start = 0
        for chunk in stage_latents:
            tokens = self.forward_func(
                chunk, time_index=t_start, ori_height=ref_h, ori_width=ref_w
            )
            t_start += chunk.shape[2]
            pieces.append(tokens.flatten(1, 2))  # b t n c -> b (t n) c
        outputs.append(torch.cat(pieces, dim=1))
    return outputs
\ No newline at end of file
from typing import Dict, Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
# flash-attn is an optional acceleration; when it is unavailable (or its import
# fails, e.g. on a CUDA/ABI mismatch) we fall back to the SDPA code path.
try:
    from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
    from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except Exception:
    # Fix: was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt. `Exception` keeps the best-effort fallback while
    # letting interpreter-exit signals through. Note pad_input/unpad_input/
    # index_first_axis are deliberately left undefined — they are only used
    # on the flash-attn path, which is disabled when these are None.
    flash_attn_func = None
    flash_attn_qkvpacked_func = None
    flash_attn_varlen_func = None
from trainer_misc import (
is_sequence_parallel_initialized,
get_sequence_parallel_group,
get_sequence_parallel_world_size,
all_to_all,
)
from .modeling_normalization import AdaLayerNormZero, AdaLayerNormContinuous, RMSNorm
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.
    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        inner_dim (`int`, *optional*): Explicit hidden dimension; overrides `dim * mult` when given.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        # Fix: this was two independent `if` statements followed by `elif`s, so
        # an unrecognized activation_fn left `act_fn` unbound and crashed with
        # UnboundLocalError below instead of a clear error.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        else:
            raise ValueError(f"Unsupported activation_fn: {activation_fn}")
        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Apply the MLP. Extra positional args and the `scale` kwarg are
        deprecated no-ops kept for backward compatibility."""
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # Fix: `deprecate` was referenced without being imported anywhere
            # in this module, so hitting this branch raised NameError instead
            # of emitting the deprecation warning.
            from diffusers.utils import deprecate
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
class VarlenFlashSelfAttentionWithT5Mask:
    """Joint text+video attention over variable-length pyramid stages using the
    flash-attn varlen kernel.

    The sequence axis of `query`/`key`/`value` concatenates several stages
    (lengths in `hidden_length`); each stage attends jointly to its own video
    tokens plus its text (T5) prompt tokens.  Padding text tokens are removed
    via the precomputed `indices` in `encoder_attention_mask` before the
    kernel call and re-inserted afterwards.
    """

    def __init__(self):
        pass

    def apply_rope(self, xq, xk, freqs_cis):
        # View the last dim as (pairs, 2) and rotate each pair by freqs_cis;
        # math is done in float32, result cast back to the input dtype.
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def __call__(
        self, query, key, value, encoder_query, encoder_key, encoder_value,
        heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
    ):
        """Returns (output_hidden, output_encoder_hidden), both flattened to
        [bs, seq, heads * head_dim]."""
        assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
        batch_size = query.shape[0]
        output_hidden = torch.zeros_like(query)
        output_encoder_hidden = torch.zeros_like(encoder_query)
        encoder_length = encoder_query.shape[1]
        qkv_list = []
        num_stages = len(hidden_length)
        encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2)  # [bs, sub_seq, 3, head, head_dim]
        qkv = torch.stack([query, key, value], dim=2)  # [bs, sub_seq, 3, head, head_dim]
        i_sum = 0
        for i_p, length in enumerate(hidden_length):
            # The text batch interleaves stages: rows i_p, i_p + num_stages, ...
            # belong to stage i_p.
            encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
            qkv_tokens = qkv[:, i_sum:i_sum+length]
            concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1)  # [bs, tot_seq, 3, nhead, dim]
            if image_rotary_emb is not None:
                # In-place update of q (index 0) and k (index 1) with the
                # stage-specific rotary embedding.
                concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
            indices = encoder_attention_mask[i_p]['indices']
            # Drop padding tokens, flattening (batch, seq) into one axis.
            qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
            i_sum += length
        token_lengths = [x_.shape[0] for x_ in qkv_list]
        qkv = torch.cat(qkv_list, dim=0)
        query, key, value = qkv.unbind(1)
        # Build cumulative sequence lengths over all stages for the varlen kernel.
        cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
        max_seqlen_q = cu_seqlens.max().item()
        max_seqlen_k = max_seqlen_q
        cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
        cu_seqlens_k = cu_seqlens_q.clone()
        output = flash_attn_varlen_func(
            query,
            key,
            value,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            dropout_p=0.0,
            causal=False,
            softmax_scale=scale,
        )
        # To merge the tokens
        i_sum = 0;token_sum = 0
        for i_p, length in enumerate(hidden_length):
            tot_token_num = token_lengths[i_p]
            stage_output = output[token_sum : token_sum + tot_token_num]
            # Re-insert padding, then split back into text and video parts.
            stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length)
            stage_encoder_hidden_output = stage_output[:, :encoder_length]
            stage_hidden_output = stage_output[:, encoder_length:]
            output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
            output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
            token_sum += tot_token_num
            i_sum += length
        # [bs, seq, heads, head_dim] -> [bs, seq, heads * head_dim]
        output_hidden = output_hidden.flatten(2, 3)
        output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
        return output_hidden, output_encoder_hidden
class SequenceParallelVarlenFlashSelfAttentionWithT5Mask:
    """Sequence-parallel variant of `VarlenFlashSelfAttentionWithT5Mask`.

    Each rank holds a shard of the sequence axis.  `all_to_all` re-shards the
    tensors so every rank sees the full sequence for a subset of heads
    (scatter heads / gather sequence) before attention, and the inverse
    `all_to_all` restores the original sharding afterwards.
    """

    def __init__(self):
        pass

    def apply_rope(self, xq, xk, freqs_cis):
        # Same pairwise rotation as the non-parallel variant (float32 math,
        # cast back to input dtype).
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def __call__(
        self, query, key, value, encoder_query, encoder_key, encoder_value,
        heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
    ):
        """Returns (output_hidden, output_encoder_hidden), both flattened to
        [bs, local_seq, local_heads * head_dim] in the original sharding."""
        assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
        batch_size = query.shape[0]
        qkv_list = []
        num_stages = len(hidden_length)
        encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2)  # [bs, sub_seq, 3, head, head_dim]
        qkv = torch.stack([query, key, value], dim=2)  # [bs, sub_seq, 3, head, head_dim]
        # To sync the encoder query, key and values
        sp_group = get_sequence_parallel_group()
        sp_group_size = get_sequence_parallel_world_size()
        # Gather the full text sequence on every rank; heads become sharded.
        encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1)  # [bs, seq, 3, sub_head, head_dim]
        output_hidden = torch.zeros_like(qkv[:,:,0])
        output_encoder_hidden = torch.zeros_like(encoder_qkv[:,:,0])
        encoder_length = encoder_qkv.shape[1]
        i_sum = 0
        for i_p, length in enumerate(hidden_length):
            # get the query, key, value from padding sequence
            # Text rows i_p, i_p + num_stages, ... belong to stage i_p.
            encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
            qkv_tokens = qkv[:, i_sum:i_sum+length]
            # Gather this stage's full video sequence (head-sharded).
            qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1)  # [bs, seq, 3, sub_head, head_dim]
            concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1)  # [bs, pad_seq, 3, nhead, dim]
            if image_rotary_emb is not None:
                # In-place rotary update of q (index 0) and k (index 1).
                concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
            indices = encoder_attention_mask[i_p]['indices']
            # Drop padding tokens before the varlen kernel.
            qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
            i_sum += length
        token_lengths = [x_.shape[0] for x_ in qkv_list]
        qkv = torch.cat(qkv_list, dim=0)
        query, key, value = qkv.unbind(1)
        # Cumulative sequence lengths over all stages for the varlen kernel.
        cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
        max_seqlen_q = cu_seqlens.max().item()
        max_seqlen_k = max_seqlen_q
        cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
        cu_seqlens_k = cu_seqlens_q.clone()
        output = flash_attn_varlen_func(
            query,
            key,
            value,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            dropout_p=0.0,
            causal=False,
            softmax_scale=scale,
        )
        # To merge the tokens
        i_sum = 0;token_sum = 0
        for i_p, length in enumerate(hidden_length):
            tot_token_num = token_lengths[i_p]
            stage_output = output[token_sum : token_sum + tot_token_num]
            # Re-insert padding; each rank contributed `length` video tokens.
            stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length * sp_group_size)
            stage_encoder_hidden_output = stage_output[:, :encoder_length]
            stage_hidden_output = stage_output[:, encoder_length:]
            # Scatter the sequence back, gathering heads (inverse re-shard).
            stage_hidden_output = all_to_all(stage_hidden_output, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
            output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
            output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
            token_sum += tot_token_num
            i_sum += length
        # Restore the text tokens' original sequence sharding as well.
        output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
        # [bs, seq, heads, head_dim] -> [bs, seq, heads * head_dim]
        output_hidden = output_hidden.flatten(2, 3)
        output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
        return output_hidden, output_encoder_hidden
class VarlenSelfAttentionWithT5Mask:
    """
    For chunk stage attention without using flash attention

    Same joint text+video stage attention as the flash variant, but built on
    torch's scaled_dot_product_attention with a dense per-stage `attention_mask`
    instead of token unpadding.
    """

    def __init__(self):
        pass

    def apply_rope(self, xq, xk, freqs_cis):
        # Pairwise rotary rotation in float32, cast back to the input dtype.
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def __call__(
        self, query, key, value, encoder_query, encoder_key, encoder_value,
        heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
    ):
        # NOTE(review): `heads` and `scale` are unused on this path; SDPA's
        # default 1/sqrt(head_dim) scaling happens to match the caller's
        # `dim_head ** -0.5` — confirm before relying on a custom scale here.
        assert attention_mask is not None, "The attention mask needed to be set"
        encoder_length = encoder_query.shape[1]
        num_stages = len(hidden_length)
        encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2)  # [bs, sub_seq, 3, head, head_dim]
        qkv = torch.stack([query, key, value], dim=2)  # [bs, sub_seq, 3, head, head_dim]
        i_sum = 0
        output_encoder_hidden_list = []
        output_hidden_list = []
        for i_p, length in enumerate(hidden_length):
            # Text rows i_p, i_p + num_stages, ... belong to stage i_p.
            encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
            qkv_tokens = qkv[:, i_sum:i_sum+length]
            concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1)  # [bs, tot_seq, 3, nhead, dim]
            if image_rotary_emb is not None:
                # In-place rotary update of q (index 0) and k (index 1).
                concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
            query, key, value = concat_qkv_tokens.unbind(2)  # [bs, tot_seq, nhead, dim]
            # SDPA expects [bs, nhead, seq, dim].
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            # with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
            stage_hidden_states = F.scaled_dot_product_attention(
                query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
            )
            stage_hidden_states = stage_hidden_states.transpose(1, 2).flatten(2, 3)  # [bs, tot_seq, dim]
            # Split joint output back into text and video parts.
            output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
            output_hidden_list.append(stage_hidden_states[:, encoder_length:])
            i_sum += length
        # Re-interleave the per-stage text outputs into the input row layout.
        output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1)  # [b n s d]
        output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s d -> (b n) s d')
        output_hidden = torch.cat(output_hidden_list, dim=1)
        return output_hidden, output_encoder_hidden
class SequenceParallelVarlenSelfAttentionWithT5Mask:
    """
    For chunk stage attention without using flash attention

    Sequence-parallel variant of `VarlenSelfAttentionWithT5Mask`: `all_to_all`
    re-shards so each rank sees the full sequence for a subset of heads
    (scatter heads / gather sequence), runs SDPA per stage, then restores the
    original sequence sharding.
    """

    def __init__(self):
        pass

    def apply_rope(self, xq, xk, freqs_cis):
        # Pairwise rotary rotation in float32, cast back to the input dtype.
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def __call__(
        self, query, key, value, encoder_query, encoder_key, encoder_value,
        heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
    ):
        # NOTE(review): as in the non-parallel variant, `heads`/`scale` are
        # unused; SDPA applies its default 1/sqrt(head_dim) scaling.
        assert attention_mask is not None, "The attention mask needed to be set"
        num_stages = len(hidden_length)
        encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2)  # [bs, sub_seq, 3, head, head_dim]
        qkv = torch.stack([query, key, value], dim=2)  # [bs, sub_seq, 3, head, head_dim]
        # To sync the encoder query, key and values
        sp_group = get_sequence_parallel_group()
        sp_group_size = get_sequence_parallel_world_size()
        # Gather the full text sequence on every rank; heads become sharded.
        encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1)  # [bs, seq, 3, sub_head, head_dim]
        encoder_length = encoder_qkv.shape[1]
        i_sum = 0
        output_encoder_hidden_list = []
        output_hidden_list = []
        for i_p, length in enumerate(hidden_length):
            # Text rows i_p, i_p + num_stages, ... belong to stage i_p.
            encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
            qkv_tokens = qkv[:, i_sum:i_sum+length]
            # Gather this stage's full video sequence (head-sharded).
            qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1)  # [bs, seq, 3, sub_head, head_dim]
            concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1)  # [bs, tot_seq, 3, nhead, dim]
            if image_rotary_emb is not None:
                # In-place rotary update of q (index 0) and k (index 1).
                concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
            query, key, value = concat_qkv_tokens.unbind(2)  # [bs, tot_seq, nhead, dim]
            # SDPA expects [bs, nhead, seq, dim].
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            stage_hidden_states = F.scaled_dot_product_attention(
                query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
            )
            stage_hidden_states = stage_hidden_states.transpose(1, 2)  # [bs, tot_seq, nhead, dim]
            output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
            output_hidden = stage_hidden_states[:, encoder_length:]
            # Scatter the sequence back, gathering heads (inverse re-shard).
            output_hidden = all_to_all(output_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
            output_hidden_list.append(output_hidden)
            i_sum += length
        # Re-interleave text outputs and restore their sequence sharding.
        output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1)  # [b n s nhead d]
        output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s h d -> (b n) s h d')
        output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
        output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
        output_hidden = torch.cat(output_hidden_list, dim=1).flatten(2, 3)
        return output_hidden, output_encoder_hidden
class JointAttention(nn.Module):
    """Joint text-video attention used by the MMDiT-style blocks.

    Projects `hidden_states` (video tokens) and `encoder_hidden_states`
    (text tokens) to q/k/v separately, then attends them jointly per pyramid
    stage through one of the Varlen* helper callables (flash-attn or SDPA,
    optionally sequence-parallel).
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: Optional[int] = None,
        context_pre_only=None,
        use_flash_attn=True,
    ):
        """
        Fixing the QKNorm, following the flux, norm the head dimension
        """
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.scale = dim_head**-0.5
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_kv_proj_dim = added_kv_proj_dim
        # Per-head q/k normalization (applied on the head dimension).
        if qk_norm is None:
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "layer_norm":
            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
        elif qk_norm == 'rms_norm':
            self.norm_q = RMSNorm(dim_head, eps=eps)
            self.norm_k = RMSNorm(dim_head, eps=eps)
        else:
            # Fix: the message previously omitted the supported 'rms_norm' option.
            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None, 'layer_norm' or 'rms_norm'")
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
        # Separate projections (and q/k norms) for the text tokens.
        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
            if qk_norm is None:
                self.norm_add_q = None
                self.norm_add_k = None
            elif qk_norm == "layer_norm":
                self.norm_add_q = nn.LayerNorm(dim_head, eps=eps)
                self.norm_add_k = nn.LayerNorm(dim_head, eps=eps)
            elif qk_norm == 'rms_norm':
                self.norm_add_q = RMSNorm(dim_head, eps=eps)
                self.norm_add_k = RMSNorm(dim_head, eps=eps)
            else:
                # Fix: same message correction as above.
                raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None, 'layer_norm' or 'rms_norm'")
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))
        if not self.context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
        # Fall back to the SDPA path when flash-attn is not importable.
        self.use_flash_attn = use_flash_attn
        if flash_attn_func is None:
            self.use_flash_attn = False
        # print(f"Using flash-attention: {self.use_flash_attn}")
        if self.use_flash_attn:
            if is_sequence_parallel_initialized():
                self.var_flash_attn = SequenceParallelVarlenFlashSelfAttentionWithT5Mask()
            else:
                self.var_flash_attn = VarlenFlashSelfAttentionWithT5Mask()
        else:
            if is_sequence_parallel_initialized():
                self.var_len_attn = SequenceParallelVarlenSelfAttentionWithT5Mask()
            else:
                self.var_len_attn = VarlenSelfAttentionWithT5Mask()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        encoder_attention_mask: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,  # [B, L, S]
        hidden_length: torch.Tensor = None,
        image_rotary_emb: torch.Tensor = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Returns (hidden_states, encoder_hidden_states) after joint attention.

        `encoder_attention_mask` feeds the flash path (unpadding indices /
        seqlens); `attention_mask` feeds the SDPA path (dense per-stage masks).
        """
        # This function is only used during training
        # `sample` projections.
        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads
        # [bs, seq, inner] -> [bs, seq, heads, head_dim]
        query = query.view(query.shape[0], -1, self.heads, head_dim)
        key = key.view(key.shape[0], -1, self.heads, head_dim)
        value = value.view(value.shape[0], -1, self.heads, head_dim)
        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)
        # `context` projections.
        encoder_hidden_states_query_proj = self.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            encoder_hidden_states_query_proj.shape[0], -1, self.heads, head_dim
        )
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            encoder_hidden_states_key_proj.shape[0], -1, self.heads, head_dim
        )
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            encoder_hidden_states_value_proj.shape[0], -1, self.heads, head_dim
        )
        if self.norm_add_q is not None:
            encoder_hidden_states_query_proj = self.norm_add_q(encoder_hidden_states_query_proj)
        if self.norm_add_k is not None:
            encoder_hidden_states_key_proj = self.norm_add_k(encoder_hidden_states_key_proj)
        # To cat the hidden and encoder hidden, perform attention compuataion, and then split
        if self.use_flash_attn:
            hidden_states, encoder_hidden_states = self.var_flash_attn(
                query, key, value,
                encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj, self.heads, self.scale, hidden_length,
                image_rotary_emb, encoder_attention_mask,
            )
        else:
            hidden_states, encoder_hidden_states = self.var_len_attn(
                query, key, value,
                encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj, self.heads, self.scale, hidden_length,
                image_rotary_emb, attention_mask,
            )
        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)
        if not self.context_pre_only:
            encoder_hidden_states = self.to_add_out(encoder_hidden_states)
        return hidden_states, encoder_hidden_states
class JointTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
    Reference: https://arxiv.org/abs/2403.03206
    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
            processing of `context` conditions.
    """

    def __init__(
        self, dim, num_attention_heads, attention_head_dim, qk_norm=None,
        context_pre_only=False, use_flash_attn=True,
    ):
        super().__init__()
        self.context_pre_only = context_pre_only
        # (The "continous" spelling is intentional — it matches the key
        # checked in the branch below.)
        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
        self.norm1 = AdaLayerNormZero(dim)
        if context_norm_type == "ada_norm_continous":
            self.norm1_context = AdaLayerNormContinuous(
                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
            )
        elif context_norm_type == "ada_norm_zero":
            self.norm1_context = AdaLayerNormZero(dim)
        else:
            raise ValueError(
                f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
            )
        self.attn = JointAttention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim // num_attention_heads,
            heads=num_attention_heads,
            out_dim=attention_head_dim,
            qk_norm=qk_norm,
            context_pre_only=context_pre_only,
            bias=True,
            use_flash_attn=use_flash_attn,
        )
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        if not context_pre_only:
            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        else:
            # Final block: the context stream is not produced further.
            self.norm2_context = None
            self.ff_context = None

    def forward(
        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor,
        encoder_attention_mask: torch.FloatTensor, temb: torch.FloatTensor,
        attention_mask: torch.FloatTensor = None, hidden_length: List = None,
        image_rotary_emb: torch.FloatTensor = None,
    ):
        """Run one MMDiT block; returns (encoder_hidden_states, hidden_states).
        When `context_pre_only`, the returned encoder_hidden_states is None."""
        # With hidden_length set, norm1 presumably takes the forward_with_pad
        # path and returns per-token modulation tensors, so the msa/mlp gates
        # below need no unsqueeze (unlike the per-sample c_* gates) — confirm
        # against AdaLayerNormZero.forward.
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb, hidden_length=hidden_length)
        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
                encoder_hidden_states, emb=temb,
            )
        # Attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask, attention_mask=attention_mask,
            hidden_length=hidden_length, image_rotary_emb=image_rotary_emb,
        )
        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa * attn_output
        hidden_states = hidden_states + attn_output
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp * ff_output
        hidden_states = hidden_states + ff_output
        # Process attention outputs for the `encoder_hidden_states`.
        if self.context_pre_only:
            encoder_hidden_states = None
        else:
            # Context gates are per-sample; broadcast over the sequence axis.
            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
            encoder_hidden_states = encoder_hidden_states + context_attn_output
            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        return encoder_hidden_states, hidden_states
\ No newline at end of file
import numbers
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers.utils import is_torch_version
# torch >= 2.1 LayerNorm natively supports the optional `bias` argument, so we
# alias it directly; older torch needs the shim class below.
if is_torch_version(">=", "2.1.0"):
    LayerNorm = nn.LayerNorm
else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()
            self.eps = eps
            # Accept either an int or an iterable of normalized-shape dims.
            if isinstance(dim, numbers.Integral):
                dim = (dim,)
            self.dim = torch.Size(dim)
            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim))
                self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
            else:
                self.weight = None
                self.bias = None

        def forward(self, input):
            # F.layer_norm accepts None weight/bias, matching the flags above.
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The mean square is computed in float32 for stability; the result is scaled
    by an optional learnable per-channel weight and returned in the input's
    original dtype.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        # Accept either an int or an iterable of normalized-shape dims.
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = torch.Size(dim)
        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # float32 accumulation of the mean square.
        mean_sq = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(mean_sq + self.eps)
        if self.weight is not None:
            # Match the weight's half precision before the affine scale.
            if self.weight.dtype in (torch.float16, torch.bfloat16):
                normed = normed.to(self.weight.dtype)
            normed = normed * self.weight
        return normed.to(orig_dtype)
class AdaLayerNormContinuous(nn.Module):
    """LayerNorm whose scale and shift are regressed from a conditioning
    embedding (SiLU -> Linear producing [scale, shift])."""

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        # Projects the conditioning embedding to per-channel (scale, shift).
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            # Uses the module-level LayerNorm alias (shim on torch < 2.1).
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward_with_pad(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
        """Stage-aware variant: x concatenates several pyramid stages along the
        sequence axis (lengths in `hidden_length`), and the conditioning rows
        interleave stages (row i_p, i_p + num_stages, ... belong to stage i_p).
        One (scale, shift) pair is broadcast over each stage's token span."""
        assert hidden_length is not None
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        # Last dim holds [scale | shift] for every token.
        batch_emb = torch.zeros_like(x).repeat(1, 1, 2)
        i_sum = 0
        num_stages = len(hidden_length)
        for i_p, length in enumerate(hidden_length):
            batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
            i_sum += length
        batch_scale, batch_shift = torch.chunk(batch_emb, 2, dim=2)
        x = self.norm(x) * (1 + batch_scale) + batch_shift
        return x

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
        if hidden_length is not None:
            return self.forward_with_pad(x, conditioning_embedding, hidden_length)
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        scale, shift = torch.chunk(emb, 2, dim=1)
        # Per-sample modulation, broadcast over the sequence axis.
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x
class AdaLayerNormZero(nn.Module):
    r"""
    Adaptive layer norm zero (adaLN-Zero).

    Produces six modulation tensors (shift/scale/gate for attention and MLP)
    from a conditioning embedding, applying the first shift/scale pair to the
    normalized input and returning the rest to the caller.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
        super().__init__()
        # No internal conditioning embedder; `emb` must be supplied by callers.
        self.emb = None
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward_with_pad(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
        hidden_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Modulate a packed multi-stage sequence; see forward() for semantics.

        x: [bs, seq_len, dim]; conditioning rows are interleaved per stage
        (row ``i_p::num_stages`` belongs to stage i_p).
        """
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        emb = self.linear(self.silu(emb))
        # Broadcast each stage's modulation over that stage's token span.
        packed = torch.zeros_like(x).repeat(1, 1, 6)
        offset = 0
        stage_count = len(hidden_length)
        for stage, seq_len in enumerate(hidden_length):
            packed[:, offset:offset + seq_len] = emb[stage::stage_count][:, None]
            offset += seq_len
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = packed.chunk(6, dim=2)
        x = self.norm(x) * (1 + scale_msa) + shift_msa
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

    def forward(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
        hidden_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the modulated input plus the remaining gate/shift/scale tensors."""
        if hidden_length is not None:
            return self.forward_with_pad(x, timestep, class_labels, hidden_dtype, emb, hidden_length)
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
\ No newline at end of file
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
from einops import rearrange
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import is_torch_version
from typing import Any, Callable, Dict, List, Optional, Union
from .modeling_embedding import PatchEmbed3D, CombinedTimestepConditionEmbeddings
from .modeling_normalization import AdaLayerNormContinuous
from .modeling_mmdit_block import JointTransformerBlock
from trainer_misc import (
is_sequence_parallel_initialized,
get_sequence_parallel_group,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
all_to_all,
)
from IPython import embed
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Build 2x2 rotary position-embedding (RoPE) rotation matrices.

    Args:
        pos: Positions of shape ``(batch, seq)``.
        dim: Per-position embedding dimension; must be even.
        theta: Base frequency of the rotary schedule.

    Returns:
        float32 tensor of shape ``(batch, seq, dim // 2, 2, 2)`` holding the
        rotation matrix ``[[cos, -sin], [sin, cos]]`` for every
        position/frequency pair.
    """
    assert dim % 2 == 0, "The dimension must be even."
    # Standard RoPE frequency schedule: theta ** (-2i / dim), in float64
    # to keep the angles accurate before the final float32 cast.
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    # The unpack also asserts pos is 2-D; the sequence length itself is unused.
    batch_size, _ = pos.shape
    # Outer product position x frequency -> rotation angle per (pos, freq).
    out = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)
    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
    return out.float()
class EmbedNDRoPE(nn.Module):
    """Multi-axis RoPE: one rotary table per id axis, concatenated along the
    frequency dimension."""

    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        # Per-axis sub-dimension; their sum covers the attention head dim.
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """ids: (..., n_axes) position ids -> rotary tables with a singleton
        head axis inserted at dim 2."""
        axis_tables = [
            rope(ids[..., axis], self.axes_dim[axis], self.theta)
            for axis in range(ids.shape[-1])
        ]
        return torch.cat(axis_tables, dim=-3).unsqueeze(2)
class PyramidDiffusionMMDiT(ModelMixin, ConfigMixin):
    """SD3-style multi-modal DiT operating on a pyramid of latents.

    Latents from several pyramid stages (ordered low to high resolution) are
    patch-embedded and packed, together with the text tokens, into one long
    sequence that all joint transformer blocks process at once;
    `merge_input` / `split_output` handle the packing and unpacking.
    Supports optional sequence parallelism via the `trainer_misc` helpers.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 24,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        pos_embed_max_size: int = 192,
        max_num_frames: int = 200,
        qk_norm: str = 'rms_norm',
        pos_embed_type: str = 'rope',
        temp_pos_embed_type: str = 'sincos',
        joint_attention_dim: int = 4096,
        use_gradient_checkpointing: bool = False,
        use_flash_attn: bool = True,
        use_temporal_causal: bool = False,
        use_t5_mask: bool = False,
        add_temp_pos_embed: bool = False,
        interp_condition_pos: bool = False,
        gradient_checkpointing_ratio: float = 0.6,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        assert temp_pos_embed_type in ['rope', 'sincos']
        # The input latent embedder; named pos_embed to remain weight-compatible with SD3.
        self.pos_embed = PatchEmbed3D(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
            max_num_frames=max_num_frames,
            pos_embed_type=pos_embed_type,
            temp_pos_embed_type=temp_pos_embed_type,
            add_temp_pos_embed=add_temp_pos_embed,
            interp_condition_pos=interp_condition_pos,
        )
        # The spatial RoPE embedding (axes_dim: 16 temporal + 24 height + 24 width).
        if pos_embed_type == 'rope':
            self.rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[16, 24, 24])
        else:
            self.rope_embed = None
        # Optional 1-D RoPE over the frame axis only.
        if temp_pos_embed_type == 'rope':
            self.temp_rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[attention_head_dim])
        else:
            self.temp_rope_embed = None
        self.time_text_embed = CombinedTimestepConditionEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim,
        )
        # Projects T5 token features into the transformer width.
        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
        # Joint text/image blocks; the final block is built with context_pre_only=True.
        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=self.inner_dim,
                    qk_norm=qk_norm,
                    context_pre_only=i == num_layers - 1,
                    use_flash_attn=use_flash_attn,
                )
                for i in range(num_layers)
            ]
        )
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
        self.gradient_checkpointing = use_gradient_checkpointing
        self.gradient_checkpointing_ratio = gradient_checkpointing_ratio
        self.patch_size = patch_size
        self.use_flash_attn = use_flash_attn
        self.use_temporal_causal = use_temporal_causal
        self.pos_embed_type = pos_embed_type
        self.temp_pos_embed_type = temp_pos_embed_type
        self.add_temp_pos_embed = add_temp_pos_embed
        if self.use_temporal_causal:
            print("Using temporal causal attention")
            # Causal masking is implemented via dense masks in merge_input.
            assert self.use_flash_attn is False, "The flash attention does not support temporal causal"
        if interp_condition_pos:
            print("We interp the position embedding of condition latents")
        # init weights
        self.initialize_weights()

    def initialize_weights(self):
        """DiT-style initialization: Xavier for linear/conv layers, normal(0.02)
        for the conditioning embedders, zeros for adaLN modulation and the
        output head (so the model starts as an identity-like residual map)."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.pos_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.pos_embed.proj.bias, 0)
        # Initialize all the conditioning to normal init
        nn.init.normal_(self.time_text_embed.timestep_embedder.linear_1.weight, std=0.02)
        nn.init.normal_(self.time_text_embed.timestep_embedder.linear_2.weight, std=0.02)
        nn.init.normal_(self.time_text_embed.text_embedder.linear_1.weight, std=0.02)
        nn.init.normal_(self.time_text_embed.text_embedder.linear_2.weight, std=0.02)
        nn.init.normal_(self.context_embedder.weight, std=0.02)
        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.transformer_blocks:
            nn.init.constant_(block.norm1.linear.weight, 0)
            nn.init.constant_(block.norm1.linear.bias, 0)
            nn.init.constant_(block.norm1_context.linear.weight, 0)
            nn.init.constant_(block.norm1_context.linear.bias, 0)
        # Zero-out output layers:
        nn.init.constant_(self.norm_out.linear.weight, 0)
        nn.init.constant_(self.norm_out.linear.bias, 0)
        nn.init.constant_(self.proj_out.weight, 0)
        nn.init.constant_(self.proj_out.bias, 0)

    @torch.no_grad()
    def _prepare_latent_image_ids(self, batch_size, temp, height, width, device):
        """Return integer (t, h, w) position ids flattened to (batch, t*h*w, 3)."""
        latent_image_ids = torch.zeros(temp, height, width, 3)
        latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[None, :, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, None, :]
        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
        latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
        return latent_image_ids.to(device=device)

    @torch.no_grad()
    def _prepare_pyramid_latent_image_ids(self, batch_size, temp_list, height_list, width_list, device):
        """Build per-stage (t, h, w) ids, mapping smaller stages onto the
        highest-resolution grid by linearly interpolating the h/w positions
        so that all stages share one spatial coordinate frame."""
        # The last stage is assumed to be the full-resolution one.
        base_width = width_list[-1]; base_height = height_list[-1]
        assert base_width == max(width_list)
        assert base_height == max(height_list)
        image_ids_list = []
        for temp, height, width in zip(temp_list, height_list, width_list):
            latent_image_ids = torch.zeros(temp, height, width, 3)
            if height != base_height:
                # Interpolate base-grid row positions down to this stage's height.
                height_pos = F.interpolate(torch.arange(base_height)[None, None, :].float(), height, mode='linear').squeeze(0, 1)
            else:
                height_pos = torch.arange(base_height).float()
            if width != base_width:
                width_pos = F.interpolate(torch.arange(base_width)[None, None, :].float(), width, mode='linear').squeeze(0, 1)
            else:
                width_pos = torch.arange(base_width).float()
            latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
            latent_image_ids[..., 1] = latent_image_ids[..., 1] + height_pos[None, :, None]
            latent_image_ids[..., 2] = latent_image_ids[..., 2] + width_pos[None, None, :]
            latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
            latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c').to(device)
            image_ids_list.append(latent_image_ids)
        return image_ids_list

    @torch.no_grad()
    def _prepare_temporal_rope_ids(self, batch_size, temp, height, width, device, start_time_stamp=0):
        """Return frame-index ids (one scalar per token) flattened to
        (batch, t*h*w, 1); all tokens of a frame share its timestamp."""
        latent_image_ids = torch.zeros(temp, height, width, 1)
        latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(start_time_stamp, start_time_stamp + temp)[:, None, None]
        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
        latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
        return latent_image_ids.to(device=device)

    @torch.no_grad()
    def _prepare_pyramid_temporal_rope_ids(self, sample, batch_size, device):
        """Build temporal rope ids per stage; within a stage, clips are laid
        out consecutively so timestamps keep increasing across clips."""
        image_ids_list = []
        for i_b, sample_ in enumerate(sample):
            if not isinstance(sample_, list):
                sample_ = [sample_]
            cur_image_ids = []
            start_time_stamp = 0
            for clip_ in sample_:
                _, _, temp, height, width = clip_.shape
                height = height // self.patch_size
                width = width // self.patch_size
                cur_image_ids.append(self._prepare_temporal_rope_ids(batch_size, temp, height, width, device, start_time_stamp=start_time_stamp))
                start_time_stamp += temp
            cur_image_ids = torch.cat(cur_image_ids, dim=1)
            image_ids_list.append(cur_image_ids)
        return image_ids_list

    def merge_input(self, sample, encoder_hidden_length, encoder_attention_mask):
        """
        Merge the input video with different resolutions into one sequence.
        Sample: from low resolution to high resolution (one entry per stage;
        an entry may itself be a list of clips, in which case the last clip
        is the trainable/noisy one).

        Returns the packed hidden states plus all per-stage bookkeeping:
        per-stage token lengths, (t, h, w) shapes, trainable token counts,
        the (possibly varlen-indexed) encoder attention mask, the dense
        attention masks (non-flash path), and the rotary embeddings.
        """
        if isinstance(sample[0], list):
            device = sample[0][-1].device
            pad_batch_size = sample[0][-1].shape[0]
        else:
            device = sample[0].device
            pad_batch_size = sample[0].shape[0]
        num_stages = len(sample)
        height_list = [];width_list = [];temp_list = []
        trainable_token_list = []
        # Record each stage's post-patchify shape; only the last clip of a
        # stage contributes trainable tokens.
        for i_b, sample_ in enumerate(sample):
            if isinstance(sample_, list):
                sample_ = sample_[-1]
            _, _, temp, height, width = sample_.shape
            height = height // self.patch_size
            width = width // self.patch_size
            temp_list.append(temp)
            height_list.append(height)
            width_list.append(width)
            trainable_token_list.append(height * width * temp)
        # prepare the RoPE embedding if needed
        if self.pos_embed_type == 'rope':
            # TODO: support the 3D Rope for video
            raise NotImplementedError("Not compatible with video generation now")
            text_ids = torch.zeros(pad_batch_size, encoder_hidden_length, 3).to(device=device)
            image_ids_list = self._prepare_pyramid_latent_image_ids(pad_batch_size, temp_list, height_list, width_list, device)
            input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
            image_rotary_emb = [self.rope_embed(input_ids) for input_ids in input_ids_list]   # [bs, seq_len, 1, head_dim // 2, 2, 2]
        else:
            if self.temp_pos_embed_type == 'rope' and self.add_temp_pos_embed:
                # Temporal-only RoPE: text tokens get timestamp 0.
                image_ids_list = self._prepare_pyramid_temporal_rope_ids(sample, pad_batch_size, device)
                text_ids = torch.zeros(pad_batch_size, encoder_attention_mask.shape[1], 1).to(device=device)
                input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
                image_rotary_emb = [self.temp_rope_embed(input_ids) for input_ids in input_ids_list]   # [bs, seq_len, 1, head_dim // 2, 2, 2]
                if is_sequence_parallel_initialized():
                    # Shard the rotary tables/ids along the sequence dim across the SP group.
                    sp_group = get_sequence_parallel_group()
                    sp_group_size = get_sequence_parallel_world_size()
                    concat_output = True if self.training else False
                    image_rotary_emb = [all_to_all(x_.repeat(1, 1, sp_group_size, 1, 1, 1), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output) for x_ in image_rotary_emb]
                    input_ids_list = [all_to_all(input_ids.repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output) for input_ids in input_ids_list]
            else:
                image_rotary_emb = None
        hidden_states = self.pos_embed(sample)  # hidden states is a list of [b c t h w] b = real_b // num_stages
        hidden_length = []
        for i_b in range(num_stages):
            hidden_length.append(hidden_states[i_b].shape[1])
        # prepare the attention mask
        if self.use_flash_attn:
            attention_mask = None
            # Flash-attention varlen path: record flat token indices and
            # per-sample sequence lengths for each stage.
            indices_list = []
            for i_p, length in enumerate(hidden_length):
                pad_attention_mask = torch.ones((pad_batch_size, length), dtype=encoder_attention_mask.dtype).to(device)
                pad_attention_mask = torch.cat([encoder_attention_mask[i_p::num_stages], pad_attention_mask], dim=1)
                if is_sequence_parallel_initialized():
                    sp_group = get_sequence_parallel_group()
                    sp_group_size = get_sequence_parallel_world_size()
                    pad_attention_mask = all_to_all(pad_attention_mask.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0)
                    pad_attention_mask = pad_attention_mask.squeeze(2)
                seqlens_in_batch = pad_attention_mask.sum(dim=-1, dtype=torch.int32)
                indices = torch.nonzero(pad_attention_mask.flatten(), as_tuple=False).flatten()
                indices_list.append(
                    {
                        'indices': indices,
                        'seqlens_in_batch': seqlens_in_batch,
                    }
                )
            encoder_attention_mask = indices_list
        else:
            assert encoder_attention_mask.shape[1] == encoder_hidden_length
            real_batch_size = encoder_attention_mask.shape[0]
            # Dense path: give every sample a distinct nonzero id so the mask
            # below becomes block-diagonal (tokens only attend within their
            # own sample); padded text tokens get id 0.
            # prepare text ids
            text_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, encoder_hidden_length)
            text_ids = text_ids.to(device)
            text_ids[encoder_attention_mask == 0] = 0
            # prepare image ids
            image_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, max(hidden_length))
            image_ids = image_ids.to(device)
            image_ids_list = []
            for i_p, length in enumerate(hidden_length):
                image_ids_list.append(image_ids[i_p::num_stages][:, :length])
            if is_sequence_parallel_initialized():
                sp_group = get_sequence_parallel_group()
                sp_group_size = get_sequence_parallel_world_size()
                concat_output = True if self.training else False
                text_ids = all_to_all(text_ids.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output).squeeze(2)
                image_ids_list = [all_to_all(image_ids_.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output).squeeze(2) for image_ids_ in image_ids_list]
            attention_mask = []
            for i_p in range(len(hidden_length)):
                image_ids = image_ids_list[i_p]
                token_ids = torch.cat([text_ids[i_p::num_stages], image_ids], dim=1)
                stage_attention_mask = rearrange(token_ids, 'b i -> b 1 i 1') == rearrange(token_ids, 'b j -> b 1 1 j')  # [bs, 1, q_len, k_len]
                if self.use_temporal_causal:
                    # Frames may only attend to same-or-earlier timestamps.
                    input_order_ids = input_ids_list[i_p].squeeze(2)
                    temporal_causal_mask = rearrange(input_order_ids, 'b i -> b 1 i 1') >= rearrange(input_order_ids, 'b j -> b 1 1 j')
                    stage_attention_mask = stage_attention_mask & temporal_causal_mask
                attention_mask.append(stage_attention_mask)
        return hidden_states, hidden_length, temp_list, height_list, width_list, trainable_token_list, encoder_attention_mask, attention_mask, image_rotary_emb

    def split_output(self, batch_hidden_states, hidden_length, temps, heights, widths, trainable_token_list):
        """Split the packed output back into one unpatchified latent per stage,
        keeping only each stage's trainable (noisy) tokens."""
        # To split the hidden states
        batch_size = batch_hidden_states.shape[0]
        output_hidden_list = []
        batch_hidden_states = torch.split(batch_hidden_states, hidden_length, dim=1)
        if is_sequence_parallel_initialized():
            sp_group_size = get_sequence_parallel_world_size()
            if self.training:
                batch_size = batch_size // sp_group_size
        for i_p, length in enumerate(hidden_length):
            width, height, temp = widths[i_p], heights[i_p], temps[i_p]
            trainable_token_num = trainable_token_list[i_p]
            hidden_states = batch_hidden_states[i_p]
            if is_sequence_parallel_initialized():
                # Gather the sequence shards back before unpatchifying.
                sp_group = get_sequence_parallel_group()
                sp_group_size = get_sequence_parallel_world_size()
                if not self.training:
                    hidden_states = hidden_states.repeat(sp_group_size, 1, 1)
                hidden_states = all_to_all(hidden_states, sp_group, sp_group_size, scatter_dim=0, gather_dim=1)
            # only the trainable token are taking part in loss computation
            hidden_states = hidden_states[:, -trainable_token_num:]
            # unpatchify
            hidden_states = hidden_states.reshape(
                shape=(batch_size, temp, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = rearrange(hidden_states, "b t h w p1 p2 c -> b t (h p1) (w p2) c")
            hidden_states = rearrange(hidden_states, "b t h w c -> b c t h w")
            output_hidden_list.append(hidden_states)
        return output_hidden_list

    def forward(
        self,
        sample: torch.FloatTensor, # [num_stages]
        encoder_hidden_states: torch.FloatTensor = None,
        encoder_attention_mask: torch.FloatTensor = None,
        pooled_projections: torch.FloatTensor = None,
        timestep_ratio: torch.FloatTensor = None,
    ):
        """Run the joint transformer over the packed pyramid sequence.

        Returns a list with one denoised latent tensor [b c t h w] per stage.
        Under sequence parallelism the sequence is sharded across ranks before
        the blocks and re-gathered in `split_output`.
        """
        # Get the timestep embedding
        temb = self.time_text_embed(timestep_ratio, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
        encoder_hidden_length = encoder_hidden_states.shape[1]
        # Get the input sequence
        hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, \
            attention_mask, image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask)
        # split the long latents if necessary
        if is_sequence_parallel_initialized():
            sp_group = get_sequence_parallel_group()
            sp_group_size = get_sequence_parallel_world_size()
            concat_output = True if self.training else False
            # sync the input hidden states
            batch_hidden_states = []
            for i_p, hidden_states_ in enumerate(hidden_states):
                assert hidden_states_.shape[1] % sp_group_size == 0, "The sequence length should be divided by sequence parallel size"
                hidden_states_ = all_to_all(hidden_states_, sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
                hidden_length[i_p] = hidden_length[i_p] // sp_group_size
                batch_hidden_states.append(hidden_states_)
            # sync the encoder hidden states
            hidden_states = torch.cat(batch_hidden_states, dim=1)
            encoder_hidden_states = all_to_all(encoder_hidden_states, sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
            temb = all_to_all(temb.unsqueeze(1).repeat(1, sp_group_size, 1), sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
            temb = temb.squeeze(1)
        else:
            hidden_states = torch.cat(hidden_states, dim=1)
        # print(hidden_length)
        for i_b, block in enumerate(self.transformer_blocks):
            # Only checkpoint the later fraction of blocks (controlled by
            # gradient_checkpointing_ratio) to trade memory for speed.
            if self.training and self.gradient_checkpointing and (i_b >= int(len(self.transformer_blocks) * self.gradient_checkpointing_ratio)):
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    temb,
                    attention_mask,
                    hidden_length,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    temb=temb,
                    attention_mask=attention_mask,
                    hidden_length=hidden_length,
                    image_rotary_emb=image_rotary_emb,
                )
        hidden_states = self.norm_out(hidden_states, temb, hidden_length=hidden_length)
        hidden_states = self.proj_out(hidden_states)
        output = self.split_output(hidden_states, hidden_length, temps, heights, widths, trainable_token_list)
        return output
import torch
import torch.nn as nn
import os
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
from typing import Any, Callable, Dict, List, Optional, Union
class SD3TextEncoderWithMask(nn.Module):
    """Bundled SD3 text-encoding stack: CLIP-L and CLIP-G for pooled
    embeddings, T5 for per-token embeddings plus an attention mask.

    All parameters are frozen at construction time; `forward` runs under
    `torch.no_grad()`.
    """
    def __init__(self, model_path, torch_dtype):
        """Load the three tokenizer/encoder pairs from standard SD3 subfolders
        under `model_path` (tokenizer[_2|_3] / text_encoder[_2|_3])."""
        super().__init__()
        # CLIP-L
        self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
        self.tokenizer_max_length = self.tokenizer.model_max_length
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder'), torch_dtype=torch_dtype)
        # CLIP-G
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
        self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder_2'), torch_dtype=torch_dtype)
        # T5
        self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
        self.text_encoder_3 = T5EncoderModel.from_pretrained(os.path.join(model_path, 'text_encoder_3'), torch_dtype=torch_dtype)
        self._freeze()

    def _freeze(self):
        """Disable gradients for every parameter of the three encoders."""
        for param in self.parameters():
            param.requires_grad = False

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 128,
    ):
        """Tokenize with T5 (padded/truncated to `max_sequence_length`) and
        return (prompt_embeds, attention_mask), each repeated
        `num_images_per_prompt` times per prompt."""
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)
        text_inputs = self.tokenizer_3(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.to(device)
        # The mask is passed to the encoder and also returned to the caller so
        # downstream attention can ignore the padded positions.
        prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
        dtype = self.text_encoder_3.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        _, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
        return prompt_embeds, prompt_attention_mask

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        clip_skip: Optional[int] = None,
        clip_model_index: int = 0,
    ):
        """Return the pooled projection from CLIP-L (`clip_model_index=0`) or
        CLIP-G (`clip_model_index=1`).

        NOTE(review): `clip_skip` is accepted for API compatibility but is
        currently unused — only the pooled output (index 0) is consumed.
        """
        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
        clip_text_encoders = [self.text_encoder, self.text_encoder_2]
        tokenizer = clip_tokenizers[clip_model_index]
        text_encoder = clip_text_encoders[clip_model_index]
        batch_size = len(prompt)
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
        pooled_prompt_embeds = prompt_embeds[0]
        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
        return pooled_prompt_embeds

    def encode_prompt(self,
        prompt,
        num_images_per_prompt=1,
        clip_skip: Optional[int] = None,
        device=None,
    ):
        """Encode `prompt` with all three encoders.

        Returns:
            prompt_embeds: T5 token embeddings.
            prompt_attention_mask: T5 attention mask.
            pooled_prompt_embeds: concatenated CLIP-L + CLIP-G pooled vectors.
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt
        pooled_prompt_embed = self._get_clip_prompt_embeds(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            clip_skip=clip_skip,
            clip_model_index=0,
        )
        pooled_prompt_2_embed = self._get_clip_prompt_embeds(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            clip_skip=clip_skip,
            clip_model_index=1,
        )
        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
        )
        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds

    def forward(self, input_prompts, device):
        """Gradient-free convenience wrapper around `encode_prompt` with one
        image per prompt."""
        with torch.no_grad():
            prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, 1, clip_skip=None, device=device)
        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
\ No newline at end of file
import torch
import os
import gc
import sys
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from einops import rearrange
from diffusers.utils.torch_utils import randn_tensor
import numpy as np
import math
import random
import PIL
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union
from accelerate import Accelerator, cpu_offload
from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler
from video_vae.modeling_causal_vae import CausalVideoVAE
import time # TODO
from trainer_misc import (
all_to_all,
is_sequence_parallel_initialized,
get_sequence_parallel_group,
get_sequence_parallel_group_rank,
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_rank,
)
from .mmdit_modules import (
PyramidDiffusionMMDiT,
SD3TextEncoderWithMask,
)
from .flux_modules import (
PyramidFluxTransformer,
FluxTextEncoderWithMask,
)
def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
    """Draw `batch_size` timestep-density samples u on the CPU.

    Schemes:
        "logit_normal": sigmoid of a normal draw — see 3.1 in the SD3 paper
            ($rf/lognorm(0.00,1.00)$); `logit_mean`/`logit_std` required.
        "mode": mode-shifted uniform draw controlled by `mode_scale`.
        anything else: plain uniform on [0, 1).
    """
    if weighting_scheme == "logit_normal":
        normal_draw = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        return torch.sigmoid(normal_draw)
    if weighting_scheme == "mode":
        base = torch.rand(size=(batch_size,), device="cpu")
        return 1 - base - mode_scale * (torch.cos(math.pi * base / 2) ** 2 - 1 + base)
    return torch.rand(size=(batch_size,), device="cpu")
def build_pyramid_dit(
    model_name : str,
    model_path : str,
    torch_dtype,
    use_flash_attn : bool,
    use_mixed_training: bool,
    interp_condition_pos: bool = True,
    use_gradient_checkpointing: bool = False,
    use_temporal_causal: bool = True,
    gradient_checkpointing_ratio: float = 0.6,
):
    """Instantiate the pyramid DiT backbone from a pretrained checkpoint.

    Returns a `PyramidFluxTransformer` for `pyramid_flux` or a
    `PyramidDiffusionMMDiT` for `pyramid_mmdit`; raises NotImplementedError
    for any other `model_name`.
    """
    # Under mixed training the backbone weights stay in full precision.
    model_dtype = torch.float32 if use_mixed_training else torch_dtype
    if model_name == "pyramid_flux":
        return PyramidFluxTransformer.from_pretrained(
            model_path, torch_dtype=model_dtype,
            use_gradient_checkpointing=use_gradient_checkpointing,
            gradient_checkpointing_ratio=gradient_checkpointing_ratio,
            use_flash_attn=use_flash_attn, use_temporal_causal=use_temporal_causal,
            interp_condition_pos=interp_condition_pos, axes_dims_rope=[16, 24, 24],
        )
    if model_name == "pyramid_mmdit":
        return PyramidDiffusionMMDiT.from_pretrained(
            model_path, torch_dtype=model_dtype, use_gradient_checkpointing=use_gradient_checkpointing,
            gradient_checkpointing_ratio=gradient_checkpointing_ratio,
            use_flash_attn=use_flash_attn, use_t5_mask=True,
            add_temp_pos_embed=True, temp_pos_embed_type='rope',
            use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
        )
    raise NotImplementedError(f"Unsupported DiT architecture, please set the model_name to `pyramid_flux` or `pyramid_mmdit`")
def build_text_encoder(
    model_name : str,
    model_path : str,
    torch_dtype,
    load_text_encoder: bool = True,
):
    """Build the masked text encoder matching the chosen DiT variant.

    Returns None when `load_text_encoder` is False; raises
    NotImplementedError for an unknown `model_name`.
    """
    if not load_text_encoder:
        return None
    if model_name == "pyramid_flux":
        return FluxTextEncoderWithMask(model_path, torch_dtype=torch_dtype)
    if model_name == "pyramid_mmdit":
        return SD3TextEncoderWithMask(model_path, torch_dtype=torch_dtype)
    raise NotImplementedError(f"Unsupported Text Encoder architecture, please set the model_name to `pyramid_flux` or `pyramid_mmdit`")
class PyramidDiTForVideoGeneration:
"""
The pyramid dit for both image and video generation, The running class wrapper
This class is mainly for fixed unit implementation: 1 + n + n + n
"""
    def __init__(self, model_path, model_dtype='bf16', model_name='pyramid_mmdit', use_gradient_checkpointing=False,
        return_log=True, model_variant="diffusion_transformer_768p", timestep_shift=1.0, stage_range=[0, 1/3, 2/3, 1],
        sample_ratios=[1, 1, 1], scheduler_gamma=1/3, use_mixed_training=False, use_flash_attn=False,
        load_text_encoder=True, load_vae=True, max_temporal_length=31, frame_per_unit=1, use_temporal_causal=True,
        corrupt_ratio=1/3, interp_condition_pos=True, stages=[1, 2, 4], video_sync_group=8, gradient_checkpointing_ratio=0.6, **kwargs,
    ):
        """Assemble the DiT backbone, text encoder, VAE and flow-matching
        scheduler from `model_path`.

        NOTE(review): `stage_range`, `sample_ratios` and `stages` use mutable
        list defaults; they are only read here, but callers should not mutate
        the returned references.
        """
        super().__init__()
        # Resolve the compute dtype; anything other than bf16/fp16 -> fp32.
        if model_dtype == 'bf16':
            torch_dtype = torch.bfloat16
        elif model_dtype == 'fp16':
            torch_dtype = torch.float16
        else:
            torch_dtype = torch.float32
        self.stages = stages
        self.sample_ratios = sample_ratios
        self.corrupt_ratio = corrupt_ratio
        dit_path = os.path.join(model_path, model_variant)
        # The dit
        self.dit = build_pyramid_dit(
            model_name, dit_path, torch_dtype,
            use_flash_attn=use_flash_attn, use_mixed_training=use_mixed_training,
            interp_condition_pos=interp_condition_pos, use_gradient_checkpointing=use_gradient_checkpointing,
            use_temporal_causal=use_temporal_causal, gradient_checkpointing_ratio=gradient_checkpointing_ratio,
        )
        # The text encoder
        self.text_encoder = build_text_encoder(
            model_name, model_path, torch_dtype, load_text_encoder=load_text_encoder,
        )
        self.load_text_encoder = load_text_encoder
        # The base video vae decoder
        if load_vae:
            self.vae = CausalVideoVAE.from_pretrained(os.path.join(model_path, 'causal_video_vae'), torch_dtype=torch_dtype, interpolate=False)
            # Freeze vae
            for parameter in self.vae.parameters():
                parameter.requires_grad = False
        else:
            self.vae = None
        self.load_vae = load_vae
        # VAE latent normalization statistics for single-image latents
        # (model-specific constants; presumably measured offline — TODO confirm).
        # For the image latent
        if model_name == "pyramid_flux":
            self.vae_shift_factor = -0.04
            self.vae_scale_factor = 1 / 1.8726
        elif model_name == "pyramid_mmdit":
            self.vae_shift_factor = 0.1490
            self.vae_scale_factor = 1 / 1.8415
        else:
            raise NotImplementedError(f"Unsupported model name : {model_name}")
        # For the video latent
        self.vae_video_shift_factor = -0.2343
        self.vae_video_scale_factor = 1 / 3.0986
        # Spatial downsampling factor of the VAE.
        self.downsample = 8
        # Configure the video training hyper-parameters
        # The video sequence: one frame + N * unit
        self.frame_per_unit = frame_per_unit
        self.max_temporal_length = max_temporal_length
        assert (max_temporal_length - 1) % frame_per_unit == 0, "The frame number should be divided by the frame number per unit"
        self.num_units_per_video = 1 + ((max_temporal_length - 1) // frame_per_unit) + int(sum(sample_ratios))
        self.scheduler = PyramidFlowMatchEulerDiscreteScheduler(
            shift=timestep_shift, stages=len(self.stages),
            stage_range=stage_range, gamma=scheduler_gamma,
        )
        print(f"The start sigmas and end sigmas of each stage is Start: {self.scheduler.start_sigmas}, End: {self.scheduler.end_sigmas}, Ori_start: {self.scheduler.ori_start_sigmas}")
        # Probability of dropping conditioning for classifier-free guidance training.
        self.cfg_rate = 0.1
        self.return_log = return_log
        self.use_flash_attn = use_flash_attn
        self.model_name = model_name
        self.sequential_offload_enabled = False
        self.accumulate_steps = 0
        self.video_sync_group = video_sync_group
def _enable_sequential_cpu_offload(self, model):
self.sequential_offload_enabled = True
torch_device = torch.device("cuda")
device_type = torch_device.type
device = torch.device(f"{device_type}:0")
offload_buffers = len(model._parameters) > 0
cpu_offload(model, device, offload_buffers=offload_buffers)
def enable_sequential_cpu_offload(self):
self._enable_sequential_cpu_offload(self.text_encoder)
self._enable_sequential_cpu_offload(self.dit)
def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
dit_checkpoint = OrderedDict()
for key in checkpoint:
if key.startswith('vae') or key.startswith('text_encoder'):
continue
if key.startswith('dit'):
new_key = key.split('.')
new_key = '.'.join(new_key[1:])
dit_checkpoint[new_key] = checkpoint[key]
else:
dit_checkpoint[key] = checkpoint[key]
load_result = self.dit.load_state_dict(dit_checkpoint, strict=True)
print(f"Load checkpoint from {checkpoint_path}, load result: {load_result}")
def load_vae_checkpoint(self, vae_checkpoint_path, model_key='model'):
checkpoint = torch.load(vae_checkpoint_path, map_location='cpu')
checkpoint = checkpoint[model_key]
loaded_checkpoint = OrderedDict()
for key in checkpoint.keys():
if key.startswith('vae.'):
new_key = key.split('.')
new_key = '.'.join(new_key[1:])
loaded_checkpoint[new_key] = checkpoint[key]
load_result = self.vae.load_state_dict(loaded_checkpoint)
print(f"Load the VAE from {vae_checkpoint_path}, load result: {load_result}")
@torch.no_grad()
def add_pyramid_noise(
self,
latents_list,
sample_ratios=[1, 1, 1],
):
"""
add the noise for each pyramidal stage
noting that, this method is a general strategy for pyramid-flow, it
can be used for both image and video training.
You can also use this method to train pyramid-flow with full-sequence
diffusion in video generation (without using temporal pyramid and autoregressive modeling)
Params:
latent_list: [low_res, mid_res, high_res] The vae latents of all stages
sample_ratios: The proportion of each stage in the training batch
"""
noise = torch.randn_like(latents_list[-1])
device = noise.device
dtype = latents_list[-1].dtype
t = noise.shape[2]
stages = len(self.stages)
tot_samples = noise.shape[0]
assert tot_samples % (int(sum(sample_ratios))) == 0
assert stages == len(sample_ratios)
height, width = noise.shape[-2], noise.shape[-1]
noise_list = [noise]
cur_noise = noise
for i_s in range(stages-1):
height //= 2;width //= 2
cur_noise = rearrange(cur_noise, 'b c t h w -> (b t) c h w')
cur_noise = F.interpolate(cur_noise, size=(height, width), mode='bilinear') * 2
cur_noise = rearrange(cur_noise, '(b t) c h w -> b c t h w', t=t)
noise_list.append(cur_noise)
noise_list = list(reversed(noise_list)) # make sure from low res to high res
# To calculate the padding batchsize and column size
batch_size = tot_samples // int(sum(sample_ratios))
column_size = int(sum(sample_ratios))
column_to_stage = {}
i_sum = 0
for i_s, column_num in enumerate(sample_ratios):
for index in range(i_sum, i_sum + column_num):
column_to_stage[index] = i_s
i_sum += column_num
noisy_latents_list = []
ratios_list = []
targets_list = []
timesteps_list = []
training_steps = self.scheduler.config.num_train_timesteps
# from low resolution to high resolution
for index in range(column_size):
i_s = column_to_stage[index]
clean_latent = latents_list[i_s][index::column_size] # [bs, c, t, h, w]
last_clean_latent = None if i_s == 0 else latents_list[i_s-1][index::column_size]
start_sigma = self.scheduler.start_sigmas[i_s]
end_sigma = self.scheduler.end_sigmas[i_s]
if i_s == 0:
start_point = noise_list[i_s][index::column_size]
else:
# Get the upsampled latent
last_clean_latent = rearrange(last_clean_latent, 'b c t h w -> (b t) c h w')
last_clean_latent = F.interpolate(last_clean_latent, size=(last_clean_latent.shape[-2] * 2, last_clean_latent.shape[-1] * 2), mode='nearest')
last_clean_latent = rearrange(last_clean_latent, '(b t) c h w -> b c t h w', t=t)
start_point = start_sigma * noise_list[i_s][index::column_size] + (1 - start_sigma) * last_clean_latent
if i_s == stages - 1:
end_point = clean_latent
else:
end_point = end_sigma * noise_list[i_s][index::column_size] + (1 - end_sigma) * clean_latent
# To sample a timestep
u = compute_density_for_timestep_sampling(
weighting_scheme='random',
batch_size=batch_size,
logit_mean=0.0,
logit_std=1.0,
mode_scale=1.29,
)
indices = (u * training_steps).long() # Totally 1000 training steps per stage
indices = indices.clamp(0, training_steps-1)
timesteps = self.scheduler.timesteps_per_stage[i_s][indices].to(device=device)
ratios = self.scheduler.sigmas_per_stage[i_s][indices].to(device=device)
while len(ratios.shape) < start_point.ndim:
ratios = ratios.unsqueeze(-1)
# interpolate the latent
noisy_latents = ratios * start_point + (1 - ratios) * end_point
last_cond_noisy_sigma = torch.rand(size=(batch_size,), device=device) * self.corrupt_ratio
# [stage1_latent, stage2_latent, ..., stagen_latent], which will be concat after patching
noisy_latents_list.append([noisy_latents.to(dtype)])
ratios_list.append(ratios.to(dtype))
timesteps_list.append(timesteps.to(dtype))
targets_list.append(start_point - end_point) # The standard rectified flow matching objective
return noisy_latents_list, ratios_list, timesteps_list, targets_list
def sample_stage_length(self, num_stages, max_units=None):
max_units_in_training = 1 + ((self.max_temporal_length - 1) // self.frame_per_unit)
cur_rank = get_rank()
self.accumulate_steps = self.accumulate_steps + 1
total_turns = max_units_in_training // self.video_sync_group
update_turn = self.accumulate_steps % total_turns
# # uniformly sampling each position
cur_highres_unit = max(int((cur_rank % self.video_sync_group + 1) + update_turn * self.video_sync_group), 1)
cur_mid_res_unit = max(1 + max_units_in_training - cur_highres_unit, 1)
cur_low_res_unit = cur_mid_res_unit
if max_units is not None:
cur_highres_unit = min(cur_highres_unit, max_units)
cur_mid_res_unit = min(cur_mid_res_unit, max_units)
cur_low_res_unit = min(cur_low_res_unit, max_units)
length_list = [cur_low_res_unit, cur_mid_res_unit, cur_highres_unit]
assert len(length_list) == num_stages
return length_list
    @torch.no_grad()
    def add_pyramid_noise_with_temporal_pyramid(
        self,
        latents_list,
        sample_ratios=[1, 1, 1],
    ):
        """
        Add the noise for each pyramidal stage, used for AR video training with temporal pyramid.

        Params:
            latents_list: [low_res, mid_res, high_res] the vae latents of all spatial
                stages, each of shape [b, c, t, h, w].
            sample_ratios: the proportion of each stage in the training batch.

        Returns:
            noisy_latents_list: per-column lists [coarsest_history, ..., corrupted_last_cond, noisy_clip],
                i.e. the condition clips followed by the trainable noisy clip.
            ratios_list: sampled interpolation ratios (broadcast to latent rank).
            timesteps_list: sampled timesteps per column.
            targets_list: rectified-flow targets (start_point - end_point) for the
                trainable clip only.
        """
        stages = len(self.stages)
        tot_samples = latents_list[0].shape[0]
        device = latents_list[0].device
        dtype = latents_list[0].dtype

        assert tot_samples % (int(sum(sample_ratios))) == 0
        assert stages == len(sample_ratios)

        noise = torch.randn_like(latents_list[-1])
        t = noise.shape[2]

        # To allocate the temporal length of each stage, ensuring the sum == constant
        max_units = 1 + (t - 1) // self.frame_per_unit

        if is_sequence_parallel_initialized():
            # take the minimum unit count across the sequence-parallel group so
            # every rank works on the same temporal length
            max_units_per_sample = torch.LongTensor([max_units]).to(device)
            sp_group = get_sequence_parallel_group()
            sp_group_size = get_sequence_parallel_world_size()
            max_units_per_sample = all_to_all(max_units_per_sample.unsqueeze(1).repeat(1, sp_group_size), sp_group, sp_group_size, scatter_dim=1, gather_dim=0).squeeze(1)
            max_units = min(max_units_per_sample.cpu().tolist())

        num_units_per_stage = self.sample_stage_length(stages, max_units=max_units)   # [The unit number of each stage]

        # we needs to sync the length alloc of each sequence parallel group
        if is_sequence_parallel_initialized():
            num_units_per_stage = torch.LongTensor(num_units_per_stage).to(device)
            sp_group_rank = get_sequence_parallel_group_rank()
            global_src_rank = sp_group_rank * get_sequence_parallel_world_size()
            torch.distributed.broadcast(num_units_per_stage, global_src_rank, group=get_sequence_parallel_group())
            num_units_per_stage = num_units_per_stage.tolist()

        height, width = noise.shape[-2], noise.shape[-1]

        # Build the noise pyramid; the * 2 restores unit noise std after the
        # bilinear 2x2 averaging (which divides the variance by 4)
        noise_list = [noise]
        cur_noise = noise
        for i_s in range(stages-1):
            height //= 2;width //= 2
            cur_noise = rearrange(cur_noise, 'b c t h w -> (b t) c h w')
            cur_noise = F.interpolate(cur_noise, size=(height, width), mode='bilinear') * 2
            cur_noise = rearrange(cur_noise, '(b t) c h w -> b c t h w', t=t)
            noise_list.append(cur_noise)

        noise_list = list(reversed(noise_list))   # make sure from low res to high res

        # To calculate the batchsize and column size
        batch_size = tot_samples // int(sum(sample_ratios))
        column_size = int(sum(sample_ratios))

        # Map each batch column to its pyramid stage according to sample_ratios
        column_to_stage = {}
        i_sum = 0
        for i_s, column_num in enumerate(sample_ratios):
            for index in range(i_sum, i_sum + column_num):
                column_to_stage[index] = i_s
            i_sum += column_num

        noisy_latents_list = []
        ratios_list = []
        targets_list = []
        timesteps_list = []
        training_steps = self.scheduler.config.num_train_timesteps

        # from low resolution to high resolution
        for index in range(column_size):
            # First prepare the trainable latent construction
            i_s = column_to_stage[index]
            clean_latent = latents_list[i_s][index::column_size]   # [bs, c, t, h, w]
            last_clean_latent = None if i_s == 0 else latents_list[i_s-1][index::column_size]
            start_sigma = self.scheduler.start_sigmas[i_s]
            end_sigma = self.scheduler.end_sigmas[i_s]

            if i_s == 0:
                # the coarsest stage starts from pure noise
                start_point = noise_list[i_s][index::column_size]
            else:
                # Get the upsampled latent of the previous (coarser) stage
                last_clean_latent = rearrange(last_clean_latent, 'b c t h w -> (b t) c h w')
                last_clean_latent = F.interpolate(last_clean_latent, size=(last_clean_latent.shape[-2] * 2, last_clean_latent.shape[-1] * 2), mode='nearest')
                last_clean_latent = rearrange(last_clean_latent, '(b t) c h w -> b c t h w', t=t)
                start_point = start_sigma * noise_list[i_s][index::column_size] + (1 - start_sigma) * last_clean_latent

            if i_s == stages - 1:
                end_point = clean_latent
            else:
                end_point = end_sigma * noise_list[i_s][index::column_size] + (1 - end_sigma) * clean_latent

            # To sample a timestep
            u = compute_density_for_timestep_sampling(
                weighting_scheme='random',
                batch_size=batch_size,
                logit_mean=0.0,
                logit_std=1.0,
                mode_scale=1.29,
            )

            indices = (u * training_steps).long()   # Totally 1000 training steps per stage
            indices = indices.clamp(0, training_steps-1)
            timesteps = self.scheduler.timesteps_per_stage[i_s][indices].to(device=device)
            ratios = self.scheduler.sigmas_per_stage[i_s][indices].to(device=device)
            # NOTE(review): noise_ratios is computed but never used below —
            # presumably kept for debugging; confirm before removing.
            noise_ratios = ratios * start_sigma + (1 - ratios) * end_sigma

            # broadcast the per-sample ratio to the latent rank
            while len(ratios.shape) < start_point.ndim:
                ratios = ratios.unsqueeze(-1)

            # interpolate the latent
            noisy_latents = ratios * start_point + (1 - ratios) * end_point

            # The flow matching object
            target_latents = start_point - end_point

            # pad the noisy previous
            num_units = num_units_per_stage[i_s]
            num_units = min(num_units, 1 + (t - 1) // self.frame_per_unit)
            actual_frames = 1 + (num_units - 1) * self.frame_per_unit

            noisy_latents = noisy_latents[:, :, :actual_frames]
            target_latents = target_latents[:, :, :actual_frames]

            clean_latent = clean_latent[:, :, :actual_frames]
            # NOTE(review): stage_noise appears unused after this point — confirm.
            stage_noise = noise_list[i_s][index::column_size][:, :, :actual_frames]

            # only the last latent takes part in training
            noisy_latents = noisy_latents[:, :, -self.frame_per_unit:]
            target_latents = target_latents[:, :, -self.frame_per_unit:]

            # per-sample sigma used to lightly corrupt the clean history clips
            last_cond_noisy_sigma = torch.rand(size=(batch_size,), device=device) * self.corrupt_ratio

            if num_units == 1:
                stage_input = [noisy_latents.to(dtype)]
            else:
                # add the random noise for the last cond clip
                last_cond_latent = clean_latent[:, :, -(2*self.frame_per_unit):-self.frame_per_unit]

                while len(last_cond_noisy_sigma.shape) < last_cond_latent.ndim:
                    last_cond_noisy_sigma = last_cond_noisy_sigma.unsqueeze(-1)

                # We adding some noise to corrupt the clean condition
                last_cond_latent = last_cond_noisy_sigma * torch.randn_like(last_cond_latent) + (1 - last_cond_noisy_sigma) * last_cond_latent

                # concat the corrupted condition and the input noisy latents
                stage_input = [noisy_latents.to(dtype), last_cond_latent.to(dtype)]

                # walk backwards through earlier units, taking them from
                # progressively coarser stages until stage 0 is reached
                cur_unit_num = 2
                cur_stage = i_s

                while cur_unit_num < num_units:
                    cur_stage = max(cur_stage - 1, 0)
                    if cur_stage == 0:
                        break
                    cur_unit_num += 1
                    cond_latents = latents_list[cur_stage][index::column_size][:, :, :actual_frames]
                    cond_latents = cond_latents[:, :, -(cur_unit_num * self.frame_per_unit) : -((cur_unit_num - 1) * self.frame_per_unit)]
                    cond_latents = last_cond_noisy_sigma * torch.randn_like(cond_latents) + (1 - last_cond_noisy_sigma) * cond_latents
                    stage_input.append(cond_latents.to(dtype))

                if cur_stage == 0 and cur_unit_num < num_units:
                    # all remaining (earliest) history comes from the coarsest stage
                    cond_latents = latents_list[0][index::column_size][:, :, :actual_frames]
                    cond_latents = cond_latents[:, :, :-(cur_unit_num * self.frame_per_unit)]
                    cond_latents = last_cond_noisy_sigma * torch.randn_like(cond_latents) + (1 - last_cond_noisy_sigma) * cond_latents
                    stage_input.append(cond_latents.to(dtype))

            # reverse so the list runs from the earliest history to the noisy clip
            stage_input = list(reversed(stage_input))
            noisy_latents_list.append(stage_input)
            ratios_list.append(ratios.to(dtype))
            timesteps_list.append(timesteps.to(dtype))
            targets_list.append(target_latents)   # The standard rectified flow matching objective

        return noisy_latents_list, ratios_list, timesteps_list, targets_list
@torch.no_grad()
def get_pyramid_latent(self, x, stage_num):
# x is the origin vae latent
vae_latent_list = []
vae_latent_list.append(x)
temp, height, width = x.shape[-3], x.shape[-2], x.shape[-1]
for _ in range(stage_num):
height //= 2
width //= 2
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = torch.nn.functional.interpolate(x, size=(height, width), mode='bilinear')
x = rearrange(x, '(b t) c h w -> b c t h w', t=temp)
vae_latent_list.append(x)
vae_latent_list = list(reversed(vae_latent_list))
return vae_latent_list
    @torch.no_grad()
    def get_vae_latent(self, video, use_temporal_pyramid=True):
        """Encode pixels to normalized vae latents and build the noisy pyramid training inputs.

        Params:
            video: raw pixels [b, 3, t, h, w] when the vae is loaded; otherwise
                assumed to already be vae latents [b, c, t, h, w] — TODO confirm
                against the data pipeline.
            use_temporal_pyramid: use the AR temporal-pyramid corruption instead
                of the purely spatial one.
        Returns:
            The (noisy_latents_list, ratios_list, timesteps_list, targets_list)
            tuple produced by the chosen add_pyramid_noise* method.
        """
        if self.load_vae:
            assert video.shape[1] == 3, "The vae is loaded, the input should be raw pixels"
            video = self.vae.encode(video).latent_dist.sample() # [b c t h w]

        if video.shape[2] == 1:
            # is image: normalize with the image statistics
            video = (video - self.vae_shift_factor) * self.vae_scale_factor
        else:
            # is video: the first frame uses the image statistics, the rest the video statistics
            # NOTE(review): this writes into `video` in place — when load_vae is
            # False the caller's tensor is modified; confirm that is intended.
            video[:, :, :1] = (video[:, :, :1] - self.vae_shift_factor) * self.vae_scale_factor
            video[:, :, 1:] = (video[:, :, 1:] - self.vae_video_shift_factor) * self.vae_video_scale_factor

        # Get the pyramidal stages (ordered from low res to high res)
        vae_latent_list = self.get_pyramid_latent(video, len(self.stages) - 1)

        if use_temporal_pyramid:
            noisy_latents_list, ratios_list, timesteps_list, targets_list = self.add_pyramid_noise_with_temporal_pyramid(vae_latent_list, self.sample_ratios)
        else:
            # Only use the spatial pyramid (without temporal ar)
            noisy_latents_list, ratios_list, timesteps_list, targets_list = self.add_pyramid_noise(vae_latent_list, self.sample_ratios)

        return noisy_latents_list, ratios_list, timesteps_list, targets_list
@torch.no_grad()
def get_text_embeddings(self, text, rand_idx, device):
if self.load_text_encoder:
batch_size = len(text) # Text is a str list
for idx in range(batch_size):
if rand_idx[idx].item():
text[idx] = ''
return self.text_encoder(text, device) # [b s c]
else:
batch_size = len(text['prompt_embeds'])
for idx in range(batch_size):
if rand_idx[idx].item():
text['prompt_embeds'][idx] = self.null_text_embeds['prompt_embed'].to(device)
text['prompt_attention_mask'][idx] = self.null_text_embeds['prompt_attention_mask'].to(device)
text['pooled_prompt_embeds'][idx] = self.null_text_embeds['pooled_prompt_embed'].to(device)
return text['prompt_embeds'], text['prompt_attention_mask'], text['pooled_prompt_embeds']
def calculate_loss(self, model_preds_list, targets_list):
loss_list = []
for model_pred, target in zip(model_preds_list, targets_list):
# Compute the loss.
loss_weight = torch.ones_like(target)
loss = torch.mean(
(loss_weight.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1,
)
loss_list.append(loss)
diffusion_loss = torch.cat(loss_list, dim=0).mean()
if self.return_log:
log = {}
split="train"
log[f'{split}/loss'] = diffusion_loss.detach()
return diffusion_loss, log
else:
return diffusion_loss, {}
def __call__(self, video, text, identifier=['video'], use_temporal_pyramid=True, accelerator: Accelerator=None):
xdim = video.ndim
device = video.device
if 'video' in identifier:
assert 'image' not in identifier
is_image = False
else:
assert 'video' not in identifier
video = video.unsqueeze(2) # 'b c h w -> b c 1 h w'
is_image = True
# TODO: now have 3 stages, firstly get the vae latents
with torch.no_grad(), accelerator.autocast():
# 10% prob drop the text
batch_size = len(video)
rand_idx = torch.rand((batch_size,)) <= self.cfg_rate
prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.get_text_embeddings(text, rand_idx, device)
noisy_latents_list, ratios_list, timesteps_list, targets_list = self.get_vae_latent(video, use_temporal_pyramid=use_temporal_pyramid)
timesteps = torch.cat([timestep.unsqueeze(-1) for timestep in timesteps_list], dim=-1)
timesteps = timesteps.reshape(-1)
assert timesteps.shape[0] == prompt_embeds.shape[0]
# DiT forward
model_preds_list = self.dit(
sample=noisy_latents_list,
timestep_ratio=timesteps,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
)
# calculate the loss
return self.calculate_loss(model_preds_list, targets_list)
def prepare_latents(
self,
batch_size,
num_channels_latents,
temp,
height,
width,
dtype,
device,
generator,
):
shape = (
batch_size,
num_channels_latents,
int(temp),
int(height) // self.downsample,
int(width) // self.downsample,
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
def sample_block_noise(self, bs, ch, temp, height, width):
gamma = self.scheduler.config.gamma
dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma)
block_number = bs * ch * temp * (height // 2) * (width // 2)
noise = torch.stack([dist.sample() for _ in range(block_number)]) # [block number, 4]
noise = rearrange(noise, '(b c t h w) (p q) -> b c t (h p) (w q)',b=bs,c=ch,t=temp,h=height//2,w=width//2,p=2,q=2)
return noise
    @torch.no_grad()
    def generate_one_unit(
        self,
        latents,
        past_conditions, # List of past conditions, contains the conditions of each stage
        prompt_embeds,
        prompt_attention_mask,
        pooled_prompt_embeds,
        num_inference_steps,
        height,
        width,
        temp,
        device,
        dtype,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        is_first_frame: bool = False,
    ):
        """
        Denoise one temporal unit, going through every pyramid stage from coarse to fine.

        Params:
            latents: initial noise for this unit at the coarsest resolution, [b, c, t, h, w].
            past_conditions: per-stage lists of history latents prepended to the DiT input.
            prompt_embeds / prompt_attention_mask / pooled_prompt_embeds: text
                conditioning (already doubled when classifier-free guidance is on).
            num_inference_steps: per-stage list of denoising step counts.
            height, width: spatial size of the *coarsest* stage latent.
            temp: number of temporal frames in this unit.
            is_first_frame: selects guidance_scale (image) vs video_guidance_scale.
        Returns:
            intermed_latents: the denoised latent of each stage; the last entry is
            the finest-resolution result.
        """
        stages = self.stages
        intermed_latents = []

        for i_s in range(len(stages)):
            self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
            timesteps = self.scheduler.timesteps

            if i_s > 0:
                # upsample the previous stage output (nearest) ...
                height *= 2; width *= 2
                latents = rearrange(latents, 'b c t h w -> (b t) c h w')
                latents = F.interpolate(latents, size=(height, width), mode='nearest')
                latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
                # ... then re-noise it to match the stage's starting sigma ("fix the stage")
                ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s]   # the original coeff of signal
                gamma = self.scheduler.config.gamma
                alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
                beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)

                # note: rebinds temp/height/width to the upsampled latent's shape
                bs, ch, temp, height, width = latents.shape
                noise = self.sample_block_noise(bs, ch, temp, height, width)
                noise = noise.to(device=device, dtype=dtype)
                latents = alpha * latents + beta * noise    # To fix the block artifact

            for idx, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

                if is_sequence_parallel_initialized():
                    # sync the input latent across the sequence-parallel group
                    sp_group_rank = get_sequence_parallel_group_rank()
                    global_src_rank = sp_group_rank * get_sequence_parallel_world_size()
                    torch.distributed.broadcast(latent_model_input, global_src_rank, group=get_sequence_parallel_group())

                # prepend the clean history clips of this stage
                latent_model_input = past_conditions[i_s] + [latent_model_input]

                noise_pred = self.dit(
                    sample=[latent_model_input],
                    timestep_ratio=timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    pooled_projections=pooled_prompt_embeds,
                )

                noise_pred = noise_pred[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    if is_first_frame:
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                    else:
                        noise_pred = noise_pred_uncond + self.video_guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    model_output=noise_pred,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                ).prev_sample

            intermed_latents.append(latents)

        return intermed_latents
    @torch.no_grad()
    def generate_i2v(
        self,
        prompt: Union[str, List[str]] = '',
        input_image: PIL.Image = None,
        temp: int = 1,
        num_inference_steps: Optional[Union[int, List[int]]] = 28,
        guidance_scale: float = 7.0,
        video_guidance_scale: float = 4.0,
        min_guidance_scale: float = 2.0,
        use_linear_guidance: bool = False,
        alpha: float = 0.5,
        negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        save_memory: bool = True,
        cpu_offloading: bool = False, # If true, reload device will be cuda.
        inference_multigpu: bool = False,
        callback: Optional[Callable[[int, int, Dict], None]] = None,
    ):
        """
        Image-to-video generation: condition on `input_image` as the first frame
        and autoregressively generate `temp` latent frames, one temporal unit at
        a time, denoising each unit through every pyramid stage.

        Params:
            prompt: prompt string or list; an aesthetics suffix is appended.
            input_image: the conditioning first frame (PIL image).
                NOTE(review): the annotation `PIL.Image` names the module, not
                the `PIL.Image.Image` class — confirm/fix the annotation.
            temp: requested number of latent frames (must be divisible by
                frame_per_unit).
            num_inference_steps: int (shared) or per-stage list of step counts.
            guidance_scale / video_guidance_scale: CFG weight for the first
                frame / subsequent frames.
            use_linear_guidance, alpha, min_guidance_scale: linearly decay the
                guidance over units when enabled.
            output_type: "latent" to return raw latents, otherwise decoded frames.
            cpu_offloading: shuttle submodules between cpu and cuda to save VRAM.
            callback: invoked per generated unit.
        Returns:
            The decoded frames, or the latent tensor when output_type == "latent".
        """
        if self.sequential_offload_enabled and not cpu_offloading:
            print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload")
            cpu_offloading=True
        device = self.device if not cpu_offloading else torch.device("cuda")
        dtype = self.dtype
        if cpu_offloading:
            # skip caring about the text encoder here as its about to be used anyways.
            if not self.sequential_offload_enabled:
                if str(self.dit.device) != "cpu":
                    print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
                    self.dit.to("cpu")
                    torch.cuda.empty_cache()
            if str(self.vae.device) != "cpu":
                print("(vae) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
                self.vae.to("cpu")
                torch.cuda.empty_cache()

        width = input_image.width
        height = input_image.height

        assert temp % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"

        if isinstance(prompt, str):
            batch_size = 1
            prompt = prompt + ", hyper quality, Ultra HD, 8K"        # adding this prompt to improve aesthetics
        else:
            assert isinstance(prompt, list)
            batch_size = len(prompt)
            prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]

        if isinstance(num_inference_steps, int):
            num_inference_steps = [num_inference_steps] * len(self.stages)

        negative_prompt = negative_prompt or ""

        # Get the text embeddings
        if cpu_offloading and not self.sequential_offload_enabled:
            self.text_encoder.to("cuda")
        prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
        negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
        if cpu_offloading:
            if not self.sequential_offload_enabled:
                self.text_encoder.to("cpu")
            self.vae.to("cuda")
            torch.cuda.empty_cache()

        if use_linear_guidance:
            max_guidance_scale = guidance_scale
            # per-unit guidance schedule, decayed by alpha and floored at min_guidance_scale
            guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp+1)]
            print(guidance_scale_list)

        self._guidance_scale = guidance_scale
        self._video_guidance_scale = video_guidance_scale

        if self.do_classifier_free_guidance:
            # order is [uncond, cond]; generate_one_unit chunks in that order
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        if is_sequence_parallel_initialized():
            # sync the prompt embedding across multiple GPUs
            sp_group_rank = get_sequence_parallel_group_rank()
            global_src_rank = sp_group_rank * get_sequence_parallel_world_size()
            torch.distributed.broadcast(prompt_embeds, global_src_rank, group=get_sequence_parallel_group())
            torch.distributed.broadcast(pooled_prompt_embeds, global_src_rank, group=get_sequence_parallel_group())
            torch.distributed.broadcast(prompt_attention_mask, global_src_rank, group=get_sequence_parallel_group())

        # Create the initial random noise
        num_channels_latents = (self.dit.config.in_channels // 4) if self.model_name == "pyramid_flux" else self.dit.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            temp,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
        )

        temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]

        latents = rearrange(latents, 'b c t h w -> (b t) c h w')
        # by default, we need to start from the block noise at the coarsest stage;
        # the * 2 per halving keeps the noise std at 1 after bilinear averaging
        for _ in range(len(self.stages)-1):
            height //= 2;width //= 2
            latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2

        latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)

        num_units = temp // self.frame_per_unit
        stages = self.stages

        # encode the image latents
        image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        input_image_tensor = image_transform(input_image).unsqueeze(0).unsqueeze(2)  # [b c 1 h w]
        input_image_latent = (self.vae.encode(input_image_tensor.to(self.vae.device, dtype=self.vae.dtype)).latent_dist.sample() - self.vae_shift_factor) * self.vae_scale_factor  # [b c 1 h w]

        if is_sequence_parallel_initialized():
            # sync the image latent across multiple GPUs
            sp_group_rank = get_sequence_parallel_group_rank()
            global_src_rank = sp_group_rank * get_sequence_parallel_world_size()
            torch.distributed.broadcast(input_image_latent, global_src_rank, group=get_sequence_parallel_group())

        generated_latents_list = [input_image_latent]    # The generated results
        # NOTE(review): last_generated_latents is assigned but never read afterwards — confirm.
        last_generated_latents = input_image_latent

        if cpu_offloading:
            self.vae.to("cpu")
            if not self.sequential_offload_enabled:
                self.dit.to("cuda")
            torch.cuda.empty_cache()

        # unit 0 is the conditioning image; generate units 1..num_units-1 autoregressively
        for unit_index in tqdm(range(1, num_units)):
            gc.collect()
            torch.cuda.empty_cache()

            if callback:
                # NOTE(review): invoked with 2 args though the annotation declares 3 — confirm.
                callback(unit_index, num_units)

            if use_linear_guidance:
                self._guidance_scale = guidance_scale_list[unit_index]
                self._video_guidance_scale = guidance_scale_list[unit_index]

            # prepare the condition latents: a clean pyramid of everything generated so far
            past_condition_latents = []
            clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)

            for i_s in range(len(stages)):
                last_cond_latent = clean_latents_list[i_s][:,:,-self.frame_per_unit:]

                stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]

                # pad the past clean latents: earlier units are taken from
                # progressively coarser stages until stage 0 covers the rest
                cur_unit_num = unit_index
                cur_stage = i_s
                cur_unit_ptx = 1

                while cur_unit_ptx < cur_unit_num:
                    cur_stage = max(cur_stage - 1, 0)
                    if cur_stage == 0:
                        break
                    cur_unit_ptx += 1
                    cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
                    stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)

                if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
                    cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
                    stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)

                stage_input = list(reversed(stage_input))
                past_condition_latents.append(stage_input)

            intermed_latents = self.generate_one_unit(
                latents[:,:,(unit_index - 1) * self.frame_per_unit:unit_index * self.frame_per_unit],
                past_condition_latents,
                prompt_embeds,
                prompt_attention_mask,
                pooled_prompt_embeds,
                num_inference_steps,
                height,
                width,
                self.frame_per_unit,
                device,
                dtype,
                generator,
                is_first_frame=False,
            )

            generated_latents_list.append(intermed_latents[-1])
            last_generated_latents = intermed_latents

        generated_latents = torch.cat(generated_latents_list, dim=2)

        if output_type == "latent":
            image = generated_latents
        else:
            if cpu_offloading:
                if not self.sequential_offload_enabled:
                    self.dit.to("cpu")
                self.vae.to("cuda")
                torch.cuda.empty_cache()
            image = self.decode_latent(generated_latents, save_memory=save_memory, inference_multigpu=inference_multigpu)
            if cpu_offloading:
                self.vae.to("cpu")
                torch.cuda.empty_cache()
                # not technically necessary, but returns the pipeline to its original state

        return image
@torch.no_grad()
def generate(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    temp: int = 1,
    num_inference_steps: Optional[Union[int, List[int]]] = 28,
    video_num_inference_steps: Optional[Union[int, List[int]]] = 28,
    guidance_scale: float = 7.0,
    video_guidance_scale: float = 7.0,
    min_guidance_scale: float = 2.0,
    use_linear_guidance: bool = False,
    alpha: float = 0.5,
    negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    output_type: Optional[str] = "pil",
    save_memory: bool = True,
    cpu_offloading: bool = False,  # If true, reload device will be cuda.
    inference_multigpu: bool = False,
    callback: Optional[Callable[[int, int, Dict], None]] = None,
):
    """Text-to-video generation with pyramidal flow matching.

    The video is produced autoregressively: the first latent frame is
    generated on its own (``is_first_frame=True``), then the remaining
    latents are generated in units of ``self.frame_per_unit`` frames,
    each unit conditioned on pyramid-resampled versions of everything
    generated so far.

    Args:
        prompt: Text prompt(s); a quality-boosting suffix is always appended.
        height / width: Target pixel size used to shape the initial noise.
        temp: Number of latent frames; must satisfy
            ``(temp - 1) % self.frame_per_unit == 0``.
        num_inference_steps: Denoising steps per pyramid stage for the first
            frame; an int is broadcast to one entry per stage.
        video_num_inference_steps: Steps per stage for the following units.
        guidance_scale / video_guidance_scale: CFG scales for the first frame
            and for subsequent video units respectively.
        min_guidance_scale / use_linear_guidance / alpha: With
            ``use_linear_guidance`` the per-unit scale decays as
            ``max(guidance_scale - alpha * t, min_guidance_scale)``.
        negative_prompt: Negative prompt used for CFG.
        generator: RNG(s) used to draw the initial noise.
        output_type: ``"latent"`` returns raw latents; anything else decodes
            through the VAE via ``decode_latent``.
        save_memory / cpu_offloading / inference_multigpu: Memory and
            multi-GPU execution knobs.
        callback: Invoked as ``callback(unit_index, num_units)`` before each
            unit.  NOTE(review): the annotation declares three arguments but
            only two are passed — confirm the expected callback signature.

    Returns:
        Decoded frames (``None`` on non-zero ranks in multi-GPU decoding),
        or the latent tensor when ``output_type == "latent"``.

    NOTE(review): ``Callable``, ``gc`` and
    ``get_sequence_parallel_group_rank`` must be importable at module
    scope — confirm the file-level imports provide them.
    """
    # Sequential offloading requires cpu offloading to be on.
    if self.sequential_offload_enabled and not cpu_offloading:
        print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload")
        cpu_offloading=True
    device = self.device if not cpu_offloading else torch.device("cuda")
    dtype = self.dtype
    if cpu_offloading:
        # skip caring about the text encoder here as its about to be used anyways.
        if not self.sequential_offload_enabled:
            if str(self.dit.device) != "cpu":
                print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
                self.dit.to("cpu")
                torch.cuda.empty_cache()
            if str(self.vae.device) != "cpu":
                print("(vae) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
                self.vae.to("cpu")
                torch.cuda.empty_cache()

    assert (temp - 1) % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"

    if isinstance(prompt, str):
        batch_size = 1
        prompt = prompt + ", hyper quality, Ultra HD, 8K"  # adding this prompt to improve aesthetics
    else:
        assert isinstance(prompt, list)
        batch_size = len(prompt)
        prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]

    # Broadcast scalar step counts to one entry per pyramid stage.
    if isinstance(num_inference_steps, int):
        num_inference_steps = [num_inference_steps] * len(self.stages)
    if isinstance(video_num_inference_steps, int):
        video_num_inference_steps = [video_num_inference_steps] * len(self.stages)

    negative_prompt = negative_prompt or ""

    # Get the text embeddings
    if cpu_offloading and not self.sequential_offload_enabled:
        self.text_encoder.to("cuda")
    prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
    negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
    if cpu_offloading:
        if not self.sequential_offload_enabled:
            self.text_encoder.to("cpu")
        self.dit.to("cuda")
        torch.cuda.empty_cache()

    if use_linear_guidance:
        max_guidance_scale = guidance_scale
        # guidance_scale_list = torch.linspace(max_guidance_scale, min_guidance_scale, temp).tolist()
        guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp)]
        print(guidance_scale_list)

    self._guidance_scale = guidance_scale
    self._video_guidance_scale = video_guidance_scale

    # For CFG, stack [negative, positive] along the batch dimension.
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

    if is_sequence_parallel_initialized():
        # sync the prompt embedding across multiple GPUs
        sp_group_rank = get_sequence_parallel_group_rank()
        global_src_rank = sp_group_rank * get_sequence_parallel_world_size()
        torch.distributed.broadcast(prompt_embeds, global_src_rank, group=get_sequence_parallel_group())
        torch.distributed.broadcast(pooled_prompt_embeds, global_src_rank, group=get_sequence_parallel_group())
        torch.distributed.broadcast(prompt_attention_mask, global_src_rank, group=get_sequence_parallel_group())

    # Create the initial random noise
    # NOTE(review): the // 4 for pyramid_flux presumably accounts for 2x2
    # latent patchification — confirm against the flux DiT config.
    num_channels_latents = (self.dit.config.in_channels // 4) if self.model_name == "pyramid_flux" else self.dit.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        temp,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
    )

    # From here on, temp/height/width refer to LATENT sizes, not pixels.
    temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]

    latents = rearrange(latents, 'b c t h w -> (b t) c h w')
    # by default, we needs to start from the block noise
    # Downsample the noise to the coarsest pyramid stage.  NOTE(review): the
    # * 2 looks like a variance-preserving rescale after bilinear averaging
    # — confirm against the training-time noise schedule.
    for _ in range(len(self.stages)-1):
        height //= 2;width //= 2
        latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
    latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)

    # One unit for the first frame + one per frame_per_unit chunk thereafter.
    num_units = 1 + (temp - 1) // self.frame_per_unit
    stages = self.stages

    generated_latents_list = []  # The generated results
    last_generated_latents = None

    for unit_index in tqdm(range(num_units)):
        gc.collect()
        torch.cuda.empty_cache()

        if callback:
            callback(unit_index, num_units)

        if use_linear_guidance:
            self._guidance_scale = guidance_scale_list[unit_index]
            self._video_guidance_scale = guidance_scale_list[unit_index]

        if unit_index == 0:
            # First frame: no history to condition on.
            past_condition_latents = [[] for _ in range(len(stages))]
            intermed_latents = self.generate_one_unit(
                latents[:,:,:1],
                past_condition_latents,
                prompt_embeds,
                prompt_attention_mask,
                pooled_prompt_embeds,
                num_inference_steps,
                height,
                width,
                1,
                device,
                dtype,
                generator,
                is_first_frame=True,
            )
        else:
            # prepare the condition latents
            past_condition_latents = []
            clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)

            for i_s in range(len(stages)):
                # Most recent unit at this stage's resolution.
                last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]

                stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]

                # pad the past clean latents
                # Walk back through earlier units, taking each successively
                # older unit from a progressively coarser stage.
                cur_unit_num = unit_index
                cur_stage = i_s
                cur_unit_ptx = 1

                while cur_unit_ptx < cur_unit_num:
                    cur_stage = max(cur_stage - 1, 0)
                    if cur_stage == 0:
                        break
                    cur_unit_ptx += 1
                    cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
                    stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)

                # Anything older than the walked-back window is conditioned
                # at the coarsest stage.
                if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
                    cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
                    stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)

                stage_input = list(reversed(stage_input))  # oldest first
                past_condition_latents.append(stage_input)

            intermed_latents = self.generate_one_unit(
                latents[:,:, 1 + (unit_index - 1) * self.frame_per_unit:1 + unit_index * self.frame_per_unit],
                past_condition_latents,
                prompt_embeds,
                prompt_attention_mask,
                pooled_prompt_embeds,
                video_num_inference_steps,
                height,
                width,
                self.frame_per_unit,
                device,
                dtype,
                generator,
                is_first_frame=False,
            )

        # Keep only the finest-stage output of this unit.
        generated_latents_list.append(intermed_latents[-1])
        last_generated_latents = intermed_latents

    generated_latents = torch.cat(generated_latents_list, dim=2)

    if output_type == "latent":
        image = generated_latents
    else:
        if cpu_offloading:
            if not self.sequential_offload_enabled:
                self.dit.to("cpu")
            self.vae.to("cuda")
            torch.cuda.empty_cache()
        image = self.decode_latent(generated_latents, save_memory=save_memory, inference_multigpu=inference_multigpu)
        if cpu_offloading:
            self.vae.to("cpu")
            torch.cuda.empty_cache()
            # not technically necessary, but returns the pipeline to its original state

    return image
def decode_latent(self, latents, save_memory=True, inference_multigpu=False):
    """Decode a latent tensor into a list of PIL frames via the video VAE.

    Notes:
      * In multi-GPU inference only rank 0 decodes; all other ranks return
        ``None``.
      * For multi-frame inputs the de-normalisation writes back into the
        given tensor, so the caller's ``latents`` is mutated in place.
    """
    # Only the main process needs vae decoding.
    if inference_multigpu and get_rank() != 0:
        return None

    if latents.shape[2] == 1:
        # Single frame: image statistics only.
        latents = (latents / self.vae_scale_factor) + self.vae_shift_factor
    else:
        # First frame uses the image statistics, the rest the video ones.
        latents[:, :, :1] = (latents[:, :, :1] / self.vae_scale_factor) + self.vae_shift_factor
        latents[:, :, 1:] = (latents[:, :, 1:] / self.vae_video_scale_factor) + self.vae_video_shift_factor

    # Smaller tiles and a smaller temporal window trade speed for peak memory.
    if save_memory:
        decode_kwargs = dict(window_size=1, tile_sample_min_size=256)
    else:
        decode_kwargs = dict(window_size=2, tile_sample_min_size=512)
    frames = self.vae.decode(latents, temporal_chunk=True, **decode_kwargs).sample

    # Scale to [0, 255], quantise to uint8, and lay out as (B*T, H, W, C)
    # numpy for PIL conversion.
    frames = (frames * 127.5 + 127.5).clamp(0, 255).byte()
    frames = rearrange(frames, "B C T H W -> (B T) H W C")
    return self.numpy_to_pil(frames.cpu().numpy())
@staticmethod
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.
    """
    # Promote a single HWC image to a one-element batch.
    batch = images[None, ...] if images.ndim == 3 else images
    if batch.shape[-1] == 1:
        # special case for grayscale (single channel) images
        return [Image.fromarray(frame.squeeze(), mode="L") for frame in batch]
    return [Image.fromarray(frame) for frame in batch]
@property
def device(self):
    """Device of the DiT parameters; the whole pipeline follows the transformer."""
    return next(self.dit.parameters()).device
@property
def dtype(self):
    """Dtype of the DiT parameters."""
    return next(self.dit.parameters()).dtype
@property
def guidance_scale(self):
    """Current CFG scale for the first frame (set by `generate`)."""
    return self._guidance_scale
@property
def video_guidance_scale(self):
    """Current CFG scale for the video units (set by `generate`)."""
    return self._video_guidance_scale
@property
def do_classifier_free_guidance(self):
    # CFG is enabled for any positive guidance scale.
    # NOTE(review): diffusers pipelines conventionally compare against 1.0;
    # here the cutoff is 0 — confirm this is intentional.
    return self._guidance_scale > 0
import os
import sys
import torch
import argparse
from PIL import Image
from diffusers.utils import export_to_video
# Add the project root directory to sys.path so that the project-local
# packages `pyramid_dit` and `trainer_misc` resolve when this script is
# run directly (not as an installed package).
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# These imports require the sys.path entry added above.
from pyramid_dit import PyramidDiTForVideoGeneration
from trainer_misc import init_distributed_mode, init_sequence_parallel_group
def get_args():
    """Build and parse the command-line arguments for multi-GPU video generation.

    Returns:
        argparse.Namespace with model/checkpoint, task, generation, and
        sequence-parallel options.
    """
    arg_parser = argparse.ArgumentParser('Pytorch Multi-process Script', add_help=False)
    arg_parser.add_argument('--model_name', default='pyramid_mmdit', type=str, help="The model name", choices=["pyramid_flux", "pyramid_mmdit"])
    arg_parser.add_argument('--model_dtype', default='bf16', type=str, help="The Model Dtype: bf16")
    arg_parser.add_argument('--model_path', required=True, type=str, help='Path to the downloaded checkpoint directory')
    arg_parser.add_argument('--variant', default='diffusion_transformer_768p', type=str)
    arg_parser.add_argument('--task', default='t2v', type=str, choices=['i2v', 't2v'])
    arg_parser.add_argument('--temp', default=16, type=int, help='The generated latent num, num_frames = temp * 8 + 1')
    arg_parser.add_argument('--sp_group_size', default=2, type=int, help="The number of GPUs used for inference, should be 2 or 4")
    arg_parser.add_argument('--sp_proc_num', default=-1, type=int, help="The number of processes used for video training, default=-1 means using all processes.")
    arg_parser.add_argument('--prompt', type=str, required=True, help="Text prompt for video generation")
    arg_parser.add_argument('--image_path', type=str, help="Path to the input image for image-to-video")
    arg_parser.add_argument('--video_guidance_scale', type=float, default=5.0, help="Video guidance scale")
    arg_parser.add_argument('--guidance_scale', type=float, default=9.0, help="Guidance scale for text-to-video")
    arg_parser.add_argument('--resolution', type=str, default='768p', choices=['768p', '384p'], help="Model resolution")
    arg_parser.add_argument('--output_path', type=str, required=True, help="Path to save the generated video")
    return arg_parser.parse_args()
def main():
    """Entry point: distributed setup, model loading, and t2v/i2v generation."""
    args = get_args()

    # Setup DDP
    init_distributed_mode(args)
    assert args.world_size == args.sp_group_size, "The sequence parallel size should match DDP world size"

    # Enable sequence parallel
    init_sequence_parallel_group(args)

    device = torch.device('cuda')
    rank = args.rank
    model_dtype = args.model_dtype

    if args.model_name == "pyramid_flux":
        assert args.variant != "diffusion_transformer_768p", "The pyramid_flux does not support high resolution now, \
            we will release it after finishing training. You can modify the model_name to pyramid_mmdit to support 768p version generation"

    model = PyramidDiTForVideoGeneration(
        args.model_path,
        model_dtype,
        model_name=args.model_name,
        model_variant=args.variant,
    )

    # Preload every component to the GPU (no cpu offloading in the
    # multi-GPU path) and enable VAE tiling to bound decode memory.
    model.vae.to(device)
    model.dit.to(device)
    model.text_encoder.to(device)
    model.vae.enable_tiling()

    if model_dtype == "bf16":
        torch_dtype = torch.bfloat16
    elif model_dtype == "fp16":
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # The video generation config
    if args.resolution == '768p':
        width = 1280
        height = 768
    else:
        width = 640
        height = 384

    try:
        if args.task == 't2v':
            prompt = args.prompt
            # autocast is disabled for fp32; otherwise run in the chosen dtype
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=(model_dtype != 'fp32'), dtype=torch_dtype):
                frames = model.generate(
                    prompt=prompt,
                    num_inference_steps=[20, 20, 20],
                    video_num_inference_steps=[10, 10, 10],
                    height=height,
                    width=width,
                    temp=args.temp,
                    guidance_scale=args.guidance_scale,
                    video_guidance_scale=args.video_guidance_scale,
                    output_type="pil",
                    save_memory=True,
                    cpu_offloading=False,
                    inference_multigpu=True,
                )
            # Only rank 0 receives decoded frames; the other ranks get None.
            if rank == 0:
                export_to_video(frames, args.output_path, fps=24)
        elif args.task == 'i2v':
            if not args.image_path:
                raise ValueError("Image path is required for image-to-video task")
            image = Image.open(args.image_path).convert("RGB")
            image = image.resize((width, height))
            prompt = args.prompt
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=(model_dtype != 'fp32'), dtype=torch_dtype):
                frames = model.generate_i2v(
                    prompt=prompt,
                    input_image=image,
                    num_inference_steps=[10, 10, 10],
                    temp=args.temp,
                    video_guidance_scale=args.video_guidance_scale,
                    output_type="pil",
                    save_memory=True,
                    cpu_offloading=False,
                    inference_multigpu=True,
                )
            if rank == 0:
                export_to_video(frames, args.output_path, fps=24)
    except Exception as e:
        if rank == 0:
            print(f"[ERROR] Error during video generation: {e}")
        raise
    finally:
        # NOTE(review): if one rank raises before reaching this barrier the
        # others may hang here — confirm the launcher tears down stuck ranks.
        torch.distributed.barrier()
# Script entry point (guarded so the module is importable without side effects).
if __name__ == "__main__":
    main()
#!/bin/bash
# Launch multi-GPU video generation via torchrun.
#
# Usage:
#   ./scripts/app_multigpu_engine.sh GPUS VARIANT MODEL_PATH TASK TEMP GUIDANCE_SCALE VIDEO_GUIDANCE_SCALE RESOLUTION OUTPUT_PATH [IMAGE_PATH] PROMPT
#
# For TASK=t2v the 10th argument is the prompt; for TASK=i2v the 10th is the
# input image path and the 11th is the prompt.
set -e

# t2v needs 10 positional arguments, i2v needs 11; catch the common mistake early.
if [ "$#" -lt 10 ]; then
    echo "Usage: $0 GPUS VARIANT MODEL_PATH TASK TEMP GUIDANCE_SCALE VIDEO_GUIDANCE_SCALE RESOLUTION OUTPUT_PATH [IMAGE_PATH] PROMPT" >&2
    exit 1
fi

GPUS=$1
VARIANT=$2
MODEL_PATH=$3
TASK=$4
TEMP=$5
GUIDANCE_SCALE=$6
VIDEO_GUIDANCE_SCALE=$7
RESOLUTION=$8
OUTPUT_PATH=$9
shift 9
# Now the remaining arguments are $@

if [ "$TASK" == "t2v" ]; then
    PROMPT="$1"
    IMAGE_ARGS=()
elif [ "$TASK" == "i2v" ]; then
    IMAGE_PATH="$1"
    PROMPT="$2"
    # An array (instead of an unquoted string) keeps paths with spaces intact.
    IMAGE_ARGS=(--image_path "$IMAGE_PATH")
else
    echo "Invalid task: $TASK" >&2
    exit 1
fi

# Get the directory where the script is located
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# Get the project root directory (parent directory of scripts)
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
# Set PYTHONPATH to include the project root directory
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
# Adjust the path to app_multigpu_engine.py
PYTHON_SCRIPT="$SCRIPT_DIR/app_multigpu_engine.py"

torchrun --nproc_per_node="$GPUS" \
    "$PYTHON_SCRIPT" \
    --model_path "$MODEL_PATH" \
    --variant "$VARIANT" \
    --task "$TASK" \
    --model_dtype bf16 \
    --temp "$TEMP" \
    --sp_group_size "$GPUS" \
    --guidance_scale "$GUIDANCE_SCALE" \
    --video_guidance_scale "$VIDEO_GUIDANCE_SCALE" \
    --resolution "$RESOLUTION" \
    --output_path "$OUTPUT_PATH" \
    --prompt "$PROMPT" \
    "${IMAGE_ARGS[@]}"
#!/bin/bash
# Batch-extract the text features for video generation training.
# The T5 encoder costs a lot of GPU memory, so pre-extracting the text
# features saves memory during training.

GPUS=2                                          # number of GPUs
MODEL_NAME=pyramid_flux                         # `pyramid_flux` or `pyramid_mmdit`
# Downloaded checkpoint dir. IMPORTANT: must match model_name, flux or mmdit (sd3).
MODEL_PATH=/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux
ANNO_FILE=annotation/customs/video_text.jsonl   # video-text annotation file path

torchrun --nproc_per_node "$GPUS" \
    tools/extract_text_features.py \
    --batch_size 1 \
    --model_dtype bf16 \
    --model_name "$MODEL_NAME" \
    --model_path "$MODEL_PATH" \
    --anno_file "$ANNO_FILE"
#!/bin/bash
# Batch-extract the VAE latents for video generation training.
# Video latent extraction is slow, so caching the latents up front saves
# training time.

GPUS=2                                          # number of GPUs
MODEL_NAME=pyramid_flux                         # NOTE: defined but not passed to the tool below
VAE_MODEL_PATH=/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux/causal_video_vae  # VAE checkpoint dir
ANNO_FILE=annotation/customs/video_text.jsonl   # video annotation file path
WIDTH=640
HEIGHT=384
NUM_FRAMES=121

torchrun --nproc_per_node "$GPUS" \
    tools/extract_video_vae_latents.py \
    --batch_size 1 \
    --model_dtype bf16 \
    --model_path "$VAE_MODEL_PATH" \
    --anno_file "$ANNO_FILE" \
    --width "$WIDTH" \
    --height "$HEIGHT" \
    --num_frames "$NUM_FRAMES"
#!/bin/bash
# Multi-GPU inference via sequence parallelism.
# pyramid-flow-sd3 supports 2 or 4 GPUs; pyramid-flow-miniflux supports 2 or 3.
# Raising the GPU count further reduces generation time.
# Requires nproc_per_node == sp_group_size.

GPUS=1                              # should be 2 or 3 (see note above)
MODEL_NAME=pyramid_flux             # or pyramid_mmdit
VARIANT=diffusion_transformer_384p
MODEL_PATH=/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux  # replace with your checkpoint path
TASK=i2v                            # i2v for image-to-video

torchrun --master-port 29001 --nproc_per_node "$GPUS" \
    inference_multigpu.py \
    --model_name "$MODEL_NAME" \
    --model_path "$MODEL_PATH" \
    --variant "$VARIANT" \
    --task "$TASK" \
    --model_dtype bf16 \
    --temp 16 \
    --sp_group_size "$GPUS"
#!/bin/bash
# This script is used for Causal VAE Training
# It undergoes a two-stage training
# Stage-1: image and video mixed training
# Stage-2: pure video training, using context parallel to load video with more video frames (up to 257 frames)

# GPUS=8 # The gpu number
# VAE_MODEL_PATH=PATH/vae_ckpt # The vae model dir
# LPIPS_CKPT=vgg_lpips.pth # The LPIPS VGG CKPT path, used for calculating the lpips loss
# OUTPUT_DIR=/PATH/output_dir # The checkpoint saving dir
# IMAGE_ANNO=annotation/image_text.jsonl # The image annotation file path
# VIDEO_ANNO=annotation/video_text.jsonl # The video annotation file path
# RESOLUTION=256 # The training resolution, default is 256
# NUM_FRAMES=17 # x * 8 + 1, the number of video frames
# BATCH_SIZE=2

# AMD ROCm device selection
export HIP_VISIBLE_DEVICES=4,5,6,7
GPUS=4 # The gpu number
VAE_MODEL_PATH=/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux/causal_video_vae # The vae model dir
LPIPS_CKPT=/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux/vgg.pth # The LPIPS VGG CKPT path, used for calculating the lpips loss
OUTPUT_DIR=./temp_vae # The checkpoint saving dir
IMAGE_ANNO=annotation/customs/vae_image.jsonl # The image annotation file path
VIDEO_ANNO=annotation/customs/vae_video.jsonl # The video annotation file path
RESOLUTION=256 # The training resolution, default is 256
NUM_FRAMES=9 # x * 8 + 1, the number of video frames
BATCH_SIZE=2

# When using add_discriminator, disc_start needs to be set to 0, otherwise an error is raised.
# Stage-1
torchrun --nproc_per_node $GPUS \
    train/train_video_vae.py \
    --num_workers 6 \
    --model_path $VAE_MODEL_PATH \
    --model_dtype bf16 \
    --lpips_ckpt $LPIPS_CKPT \
    --output_dir $OUTPUT_DIR \
    --image_anno $IMAGE_ANNO \
    --video_anno $VIDEO_ANNO \
    --use_image_video_mixed_training \
    --image_mix_ratio 0.1 \
    --resolution $RESOLUTION \
    --max_frames $NUM_FRAMES \
    --disc_start 0 \
    --kl_weight 1e-12 \
    --pixelloss_weight 10.0 \
    --perceptual_weight 1.0 \
    --disc_weight 0.5 \
    --batch_size $BATCH_SIZE \
    --opt adamw \
    --opt_betas 0.9 0.95 \
    --seed 42 \
    --weight_decay 1e-3 \
    --clip_grad 1.0 \
    --lr 1e-4 \
    --lr_disc 1e-4 \
    --warmup_epochs 0 \
    --epochs 10 \
    --iters_per_epoch 1000 \
    --print_freq 40 \
    --save_ckpt_freq 1 \
    --add_discriminator

# Stage-2
CONTEXT_SIZE=1 # context parallel size, GPUS % CONTEXT_SIZE == 0
NUM_FRAMES=18 # 17 * CONTEXT_SIZE + 1
VAE_CKPT_PATH=./temp_vae/checkpoint.pth # The stage-1 trained ckpt

torchrun --nproc_per_node $GPUS \
    train/train_video_vae.py \
    --num_workers 6 \
    --model_path $VAE_MODEL_PATH \
    --model_dtype bf16 \
    --pretrained_vae_weight $VAE_CKPT_PATH \
    --use_context_parallel \
    --context_size $CONTEXT_SIZE \
    --lpips_ckpt $LPIPS_CKPT \
    --output_dir $OUTPUT_DIR \
    --video_anno $VIDEO_ANNO \
    --image_mix_ratio 0.0 \
    --resolution $RESOLUTION \
    --max_frames $NUM_FRAMES \
    --disc_start 0 \
    --kl_weight 1e-12 \
    --pixelloss_weight 10.0 \
    --perceptual_weight 1.0 \
    --disc_weight 0.5 \
    --batch_size $BATCH_SIZE \
    --opt adamw \
    --opt_betas 0.9 0.95 \
    --seed 42 \
    --weight_decay 1e-3 \
    --clip_grad 1.0 \
    --lr 1e-4 \
    --lr_disc 1e-4 \
    --warmup_epochs 1 \
    --epochs 10 \
    --iters_per_epoch 1000 \
    --print_freq 40 \
    --save_ckpt_freq 1 \
    --add_discriminator
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment