import os
import torch
__all__ = [
"C_SCALE",
"PROMPT_TEMPLATE",
"MODEL_BASE",
"PRECISIONS",
"NORMALIZATION_TYPE",
"ACTIVATION_TYPE",
"VAE_PATH",
"TEXT_ENCODER_PATH",
"TOKENIZER_PATH",
"TEXT_PROJECTION",
"DATA_TYPE",
"NEGATIVE_PROMPT",
]
PRECISION_TO_TYPE = {
'fp32': torch.float32,
'fp16': torch.float16,
'bf16': torch.bfloat16,
}
# =================== Constant Values =====================
# Computation scale factor: 1P = 1_000_000_000_000_000. TensorBoard displays the value in
# PetaFLOPS to avoid overflow errors when logging.
C_SCALE = 1_000_000_000_000_000
# When using a decoder-only model as the text encoder, we must provide a prompt template
# to instruct the model on how to describe the image or video.
# --------------------------------------------------------------------
PROMPT_TEMPLATE_ENCODE = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
PROMPT_TEMPLATE_ENCODE_VIDEO = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
PROMPT_TEMPLATE = {
"dit-llm-encode": {
"template": PROMPT_TEMPLATE_ENCODE,
"crop_start": 36,
},
"dit-llm-encode-video": {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
"crop_start": 95,
},
}
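# --------------------------------------------------------------------
# Minimal illustrative sketch (ours, not part of the original constants): how a user caption is
# combined with one of the templates above. The `crop_start` value is assumed to be consumed by
# the TextEncoder to drop the template-prefix tokens from the encoder hidden states, so that only
# the tokens of the user prompt remain.
def _example_fill_prompt_template(user_prompt: str, key: str = "dit-llm-encode-video") -> str:
    entry = PROMPT_TEMPLATE[key]
    full_prompt = entry["template"].format(user_prompt)  # insert the caption into the template
    _num_template_tokens = entry["crop_start"]  # e.g. 95 template-prefix tokens to strip later
    return full_prompt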
# ======================= Model ======================
PRECISIONS = {"fp32", "fp16", "bf16"}
NORMALIZATION_TYPE = {"layer", "rms"}
ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"}
# =================== Model Path =====================
MODEL_BASE = os.getenv("MODEL_BASE", "./ckpts")
# =================== Data =======================
DATA_TYPE = {"image", "video", "image_video"}
# 3D VAE
VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"}
# Text Encoder
TEXT_ENCODER_PATH = {
"clipL": f"{MODEL_BASE}/text_encoder_2",
"llm": f"{MODEL_BASE}/text_encoder",
}
# Tokenizer
TOKENIZER_PATH = {
"clipL": f"{MODEL_BASE}/text_encoder_2",
"llm": f"{MODEL_BASE}/text_encoder",
}
TEXT_PROJECTION = {
"linear", # Default, an nn.Linear() layer
"single_refiner", # Single TokenRefiner. Refer to LI-DiT
}
from .pipelines import HunyuanVideoPipeline
from .schedulers import FlowMatchDiscreteScheduler
from .pipeline_hunyuan_video import HunyuanVideoPipeline
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
import inspect
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import torch
import torch.distributed as dist
import numpy as np
from dataclasses import dataclass
from packaging import version
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput
from ...constants import PRECISION_TO_TYPE
from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from ...text_encoder import TextEncoder
from ...modules import HYVideoDiffusionTransformer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """"""
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(
dim=list(range(1, noise_pred_text.ndim)), keepdim=True
)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = (
guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
)
return noise_cfg
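# Minimal illustrative sketch (ours, not part of the original file): how `rescale_noise_cfg` is
# typically applied right after the classifier-free-guidance combination, mirroring the denoising
# loop further below. Tensor shapes are arbitrary placeholders.
def _example_rescale_noise_cfg():
    noise_pred_uncond = torch.randn(2, 16, 8, 8)
    noise_pred_text = torch.randn(2, 16, 8, 8)
    guidance_scale = 7.5
    # standard CFG combination
    noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # pull the per-sample std of the CFG result back toward that of the text-conditioned prediction
    return rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)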
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError(
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
)
if timesteps is not None:
accepts_timesteps = "timesteps" in set(
inspect.signature(scheduler.set_timesteps).parameters.keys()
)
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(
inspect.signature(scheduler.set_timesteps).parameters.keys()
)
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
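# Minimal illustrative sketch (ours): the ways `retrieve_timesteps` can be driven. The `scheduler`
# argument is assumed to be any diffusers-style scheduler exposing `set_timesteps` and `timesteps`
# (e.g. the FlowMatchDiscreteScheduler used by this pipeline).
def _example_retrieve_timesteps(scheduler):
    # 1) plain step count
    timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
    # 2) custom sigmas, only valid if the scheduler's `set_timesteps` accepts a `sigmas` argument:
    # timesteps, num_steps = retrieve_timesteps(scheduler, sigmas=[1.0, 0.75, 0.5, 0.25, 0.0], device="cpu")
    return timesteps, num_steps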
@dataclass
class HunyuanVideoPipelineOutput(BaseOutput):
videos: Union[torch.Tensor, np.ndarray]
class HunyuanVideoPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using HunyuanVideo.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`TextEncoder`]):
Frozen text-encoder.
text_encoder_2 ([`TextEncoder`]):
Frozen text-encoder_2.
transformer ([`HYVideoDiffusionTransformer`]):
A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = ["text_encoder_2"]
_exclude_from_cpu_offload = ["transformer"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: TextEncoder,
transformer: HYVideoDiffusionTransformer,
scheduler: KarrasDiffusionSchedulers,
text_encoder_2: Optional[TextEncoder] = None,
progress_bar_config: Dict[str, Any] = None,
args=None,
):
super().__init__()
# ==========================================================================================
if progress_bar_config is None:
progress_bar_config = {}
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
self._progress_bar_config.update(progress_bar_config)
self.args = args
# ==========================================================================================
if (
hasattr(scheduler.config, "steps_offset")
and scheduler.config.steps_offset != 1
):
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate(
"steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if (
hasattr(scheduler.config, "clip_sample")
and scheduler.config.clip_sample is True
):
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate(
"clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
transformer=transformer,
scheduler=scheduler,
text_encoder_2=text_encoder_2,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def encode_prompt(
self,
prompt,
device,
num_videos_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_attention_mask: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
text_encoder: Optional[TextEncoder] = None,
data_type: Optional[str] = "image",
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device (`torch.device`):
torch device
num_videos_per_prompt (`int`):
number of videos that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for `prompt_embeds`.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for `negative_prompt_embeds`.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
text_encoder (TextEncoder, *optional*):
The text encoder to use. Defaults to `self.text_encoder` if not provided.
data_type (`str`, *optional*):
Whether the prompt describes an `"image"` or a `"video"`; selects the corresponding prompt template.
"""
if text_encoder is None:
text_encoder = self.text_encoder
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
else:
scale_lora_layers(text_encoder.model, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
if clip_skip is None:
prompt_outputs = text_encoder.encode(
text_inputs, data_type=data_type, device=device
)
prompt_embeds = prompt_outputs.hidden_state
else:
prompt_outputs = text_encoder.encode(
text_inputs,
output_hidden_states=True,
data_type=data_type,
device=device,
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = text_encoder.model.text_model.final_layer_norm(
prompt_embeds
)
attention_mask = prompt_outputs.attention_mask
if attention_mask is not None:
attention_mask = attention_mask.to(device)
bs_embed, seq_len = attention_mask.shape
attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
attention_mask = attention_mask.view(
bs_embed * num_videos_per_prompt, seq_len
)
if text_encoder is not None:
prompt_embeds_dtype = text_encoder.dtype
elif self.transformer is not None:
prompt_embeds_dtype = self.transformer.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
if prompt_embeds.ndim == 2:
bs_embed, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
else:
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
bs_embed * num_videos_per_prompt, seq_len, -1
)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(
uncond_tokens, text_encoder.tokenizer
)
# max_length = prompt_embeds.shape[1]
uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
negative_prompt_outputs = text_encoder.encode(
uncond_input, data_type=data_type, device=device
)
negative_prompt_embeds = negative_prompt_outputs.hidden_state
negative_attention_mask = negative_prompt_outputs.attention_mask
if negative_attention_mask is not None:
negative_attention_mask = negative_attention_mask.to(device)
_, seq_len = negative_attention_mask.shape
negative_attention_mask = negative_attention_mask.repeat(
1, num_videos_per_prompt
)
negative_attention_mask = negative_attention_mask.view(
batch_size * num_videos_per_prompt, seq_len
)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(
dtype=prompt_embeds_dtype, device=device
)
if negative_prompt_embeds.ndim == 2:
negative_prompt_embeds = negative_prompt_embeds.repeat(
1, num_videos_per_prompt
)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_videos_per_prompt, -1
)
else:
negative_prompt_embeds = negative_prompt_embeds.repeat(
1, num_videos_per_prompt, 1
)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_videos_per_prompt, seq_len, -1
)
if text_encoder is not None:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(text_encoder.model, lora_scale)
return (
prompt_embeds,
negative_prompt_embeds,
attention_mask,
negative_attention_mask,
)
def decode_latents(self, latents, enable_tiling=True):
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
latents = 1 / self.vae.config.scaling_factor * latents
if enable_tiling:
self.vae.enable_tiling()
image = self.vae.decode(latents, return_dict=False)[0]
else:
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
if image.ndim == 4:
image = image.cpu().permute(0, 2, 3, 1).float()
else:
image = image.cpu().float()
return image
def prepare_extra_func_kwargs(self, func, kwargs):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
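# Worked example (illustrative): given {"generator": generator, "eta": eta}, diffusers'
# DDIMScheduler.step accepts both names and keeps them, while the FlowMatchDiscreteScheduler.step
# defined in this repo accepts neither, so the filtered dict comes back empty for that scheduler.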
extra_step_kwargs = {}
for k, v in kwargs.items():
accepts = k in set(inspect.signature(func).parameters.keys())
if accepts:
extra_step_kwargs[k] = v
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
video_length,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
vae_ver="88-4c-sd",
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
)
if video_length is not None:
if "884" in vae_ver:
if video_length != 1 and (video_length - 1) % 4 != 0:
raise ValueError(
f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
)
elif "888" in vae_ver:
if video_length != 1 and (video_length - 1) % 8 != 0:
raise ValueError(
f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
)
if callback_steps is not None and (
not isinstance(callback_steps, int) or callback_steps <= 0
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs
for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (
not isinstance(prompt, str) and not isinstance(prompt, list)
):
raise ValueError(
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
video_length,
dtype,
device,
generator,
latents=None,
):
shape = (
batch_size,
num_channels_latents,
video_length,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
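# Worked example (illustrative, assuming vae_scale_factor == 8): for a 720x1280, 129-frame request
# with an "884" VAE, __call__ passes video_length already compressed to (129 - 1) // 4 + 1 = 33,
# so the latent shape is (batch_size, num_channels_latents, 33, 90, 160).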
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(
shape, generator=generator, device=device, dtype=dtype
)
else:
latents = latents.to(device)
# Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
if hasattr(self.scheduler, "init_noise_sigma"):
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
self,
w: torch.Tensor,
embedding_dim: int = 512,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
w (`torch.Tensor`):
Guidance scale values used to generate the embedding vectors, which subsequently enrich the timestep embeddings.
embedding_dim (`int`, *optional*, defaults to 512):
Dimension of the embeddings to generate.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
Data type of the generated embeddings.
Returns:
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
"""
assert len(w.shape) == 1
w = w * 1000.0
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
# return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
height: int,
width: int,
video_length: int,
data_type: str = "video",
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[
Union[
Callable[[int, int, Dict], None],
PipelineCallback,
MultiPipelineCallbacks,
]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
vae_ver: str = "88-4c-sd",
enable_tiling: bool = False,
n_tokens: Optional[int] = None,
embedded_guidance_scale: Optional[float] = None,
**kwargs,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide video generation. If not defined, you need to pass `prompt_embeds`.
height (`int`):
The height in pixels of the generated video frames.
width (`int`):
The width in pixels of the generated video frames.
video_length (`int`):
The number of frames in the generated video.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during inference with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~HunyuanVideoPipelineOutput`] or `torch.Tensor`:
If `return_dict` is `True`, a [`HunyuanVideoPipelineOutput`] is returned; otherwise the
generated video tensor is returned directly.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default height and width to unet
# height = height or self.transformer.config.sample_size * self.vae_scale_factor
# width = width or self.transformer.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
video_length,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
vae_ver=vae_ver,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
# 3. Encode input prompt
lora_scale = (
self.cross_attention_kwargs.get("scale", None)
if self.cross_attention_kwargs is not None
else None
)
(
prompt_embeds,
negative_prompt_embeds,
prompt_mask,
negative_prompt_mask,
) = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
attention_mask=attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_attention_mask=negative_attention_mask,
lora_scale=lora_scale,
clip_skip=self.clip_skip,
data_type=data_type,
)
if self.text_encoder_2 is not None:
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_mask_2,
negative_prompt_mask_2,
) = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=None,
attention_mask=None,
negative_prompt_embeds=None,
negative_attention_mask=None,
lora_scale=lora_scale,
clip_skip=self.clip_skip,
text_encoder=self.text_encoder_2,
data_type=data_type,
)
else:
prompt_embeds_2 = None
negative_prompt_embeds_2 = None
prompt_mask_2 = None
negative_prompt_mask_2 = None
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if prompt_mask is not None:
prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
if prompt_embeds_2 is not None:
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
if prompt_mask_2 is not None:
prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
# 4. Prepare timesteps
extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
self.scheduler.set_timesteps, {"n_tokens": n_tokens}
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
**extra_set_timesteps_kwargs,
)
if "884" in vae_ver:
video_length = (video_length - 1) // 4 + 1
elif "888" in vae_ver:
video_length = (video_length - 1) // 8 + 1
else:
video_length = video_length
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
video_length,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_func_kwargs(
self.scheduler.step,
{"generator": generator, "eta": eta},
)
target_dtype = PRECISION_TO_TYPE[self.args.precision]
autocast_enabled = (
target_dtype != torch.float32
) and not self.args.disable_autocast
vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
vae_autocast_enabled = (
vae_dtype != torch.float32
) and not self.args.disable_autocast
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
# if is_progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2)
if self.do_classifier_free_guidance
else latents
)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
t_expand = t.repeat(latent_model_input.shape[0])
guidance_expand = (
torch.tensor(
[embedded_guidance_scale] * latent_model_input.shape[0],
dtype=torch.float32,
device=device,
).to(target_dtype)
* 1000.0
if embedded_guidance_scale is not None
else None
)
# predict the noise residual
with torch.autocast(
device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
):
noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
latent_model_input, # [2, 16, 33, 24, 42]
t_expand, # [2]
text_states=prompt_embeds, # [2, 256, 4096]
text_mask=prompt_mask, # [2, 256]
text_states_2=prompt_embeds_2, # [2, 768]
freqs_cos=freqs_cis[0], # [seqlen, head_dim]
freqs_sin=freqs_cis[1], # [seqlen, head_dim]
guidance=guidance_expand,
return_dict=True,
)["x"]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_text,
guidance_rescale=self.guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop(
"negative_prompt_embeds", negative_prompt_embeds
)
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
if progress_bar is not None:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
expand_temporal_dim = False
if len(latents.shape) == 4:
if isinstance(self.vae, AutoencoderKLCausal3D):
latents = latents.unsqueeze(2)
expand_temporal_dim = True
elif len(latents.shape) == 5:
pass
else:
raise ValueError(
f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
)
if (
hasattr(self.vae.config, "shift_factor")
and self.vae.config.shift_factor
):
latents = (
latents / self.vae.config.scaling_factor
+ self.vae.config.shift_factor
)
else:
latents = latents / self.vae.config.scaling_factor
with torch.autocast(
device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
):
if enable_tiling:
self.vae.enable_tiling()
image = self.vae.decode(
latents, return_dict=False, generator=generator
)[0]
else:
image = self.vae.decode(
latents, return_dict=False, generator=generator
)[0]
if expand_temporal_dim or image.shape[2] == 1:
image = image.squeeze(2)
else:
image = latents
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().float()
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return image
return HunyuanVideoPipelineOutput(videos=image)
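# Minimal illustrative sketch (ours, not part of the original file): driving the pipeline above.
# Constructing `pipe` requires real checkpoints, and the rotary tables `freqs_cis` come from a
# helper such as `get_nd_rotary_pos_embed` (used elsewhere in this repo), so both are taken as
# arguments here; the argument values below are illustrative for a 720p text-to-video run.
def _example_run_pipeline(pipe: "HunyuanVideoPipeline", freqs_cis, generator=None):
    out = pipe(
        prompt="A cat runs across a sunlit garden.",
        height=720,
        width=1280,
        video_length=129,            # with an "884" VAE this becomes (129 - 1) // 4 + 1 = 33 latent frames
        num_inference_steps=50,
        guidance_scale=1.0,          # CFG disabled; the embedded guidance signal is used instead
        embedded_guidance_scale=6.0,
        freqs_cis=freqs_cis,         # (freqs_cos, freqs_sin) rotary position tables
        vae_ver="884-16c-hy",
        enable_tiling=True,
        generator=generator,
    )
    return out.videos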
from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.schedulers.scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler for discrete flow matching.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
reverse (`bool`, defaults to `True`):
Whether to reverse the timestep schedule.
solver (`str`, defaults to `"euler"`):
The solver to use for the update step. Only `"euler"` is currently supported.
n_tokens (`int`, *optional*):
Number of tokens in the input sequence (registered to the config but not used by this scheduler).
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
reverse: bool = True,
solver: str = "euler",
n_tokens: Optional[int] = None,
):
sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
if not reverse:
sigmas = sigmas.flip(0)
self.sigmas = sigmas
# the value fed to model
self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
self._step_index = None
self._begin_index = None
self.supported_solver = ["euler"]
if solver not in self.supported_solver:
raise ValueError(
f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
)
@property
def step_index(self):
"""
The index counter for the current timestep. It increases by 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index of the first timestep. It should be set from the pipeline with the `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
n_tokens: int = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
n_tokens (`int`, *optional*):
Number of tokens in the input sequence.
"""
self.num_inference_steps = num_inference_steps
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
sigmas = self.sd3_time_shift(sigmas)
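# Worked example (illustrative): with shift = 7.0, a mid-schedule value t = 0.5 maps to
# (7.0 * 0.5) / (1 + 6.0 * 0.5) = 0.875, i.e. the shift pushes sigmas toward 1 so more of the
# schedule is spent at high noise levels.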
if not self.config.reverse:
sigmas = 1 - sigmas
self.sigmas = sigmas
self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
dtype=torch.float32, device=device
)
# Reset step index
self._step_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def scale_model_input(
self, sample: torch.Tensor, timestep: Optional[int] = None
) -> torch.Tensor:
return sample
def sd3_time_shift(self, t: torch.Tensor):
return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
"""
Advance the sample by one Euler step along the learned flow, using the model output (the
predicted velocity) to propagate the sample from the current timestep to the next one.
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or a tuple.
Returns:
[`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
If `return_dict` is `True`, a [`FlowMatchDiscreteSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
if self.config.solver == "euler":
prev_sample = sample + model_output.to(torch.float32) * dt
else:
raise ValueError(
f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
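# Minimal illustrative sketch (ours, not part of the original file): running the scheduler above
# on its own with a dummy constant "velocity" prediction, to show the set_timesteps -> step loop.
def _example_flow_match_euler_loop(num_inference_steps: int = 4):
    scheduler = FlowMatchDiscreteScheduler(shift=7.0)
    scheduler.set_timesteps(num_inference_steps, device="cpu")
    sample = torch.randn(1, 4, 2, 8, 8)        # (B, C, T, H, W) dummy latents
    for t in scheduler.timesteps:
        velocity = torch.zeros_like(sample)    # stand-in for the transformer output
        sample = scheduler.step(velocity, t, sample, return_dict=False)[0]
    return sample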
import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union
from pathlib import Path
from loguru import logger
import torch
import torch.distributed as dist
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
from hyvideo.modules.fp8_optimization import convert_fp8_linear
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
try:
import xfuser
from xfuser.core.distributed import (
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_sp_group,
initialize_model_parallel,
init_distributed_environment
)
except ImportError:
xfuser = None
get_sequence_parallel_world_size = None
get_sequence_parallel_rank = None
get_sp_group = None
initialize_model_parallel = None
init_distributed_environment = None
def parallelize_transformer(pipe):
transformer = pipe.transformer
original_forward = transformer.forward
@functools.wraps(transformer.__class__.forward)
def new_forward(
self,
x: torch.Tensor,
t: torch.Tensor, # Should be in range(0, 1000).
text_states: torch.Tensor = None,
text_mask: torch.Tensor = None,  # Currently unused.
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
return_dict: bool = True,
):
if x.shape[-2] // 2 % get_sequence_parallel_world_size() == 0:
# try to split x by height
split_dim = -2
elif x.shape[-1] // 2 % get_sequence_parallel_world_size() == 0:
# try to split x by width
split_dim = -1
else:
raise ValueError(f"Cannot split video sequence into ulysses_degree x ring_degree ({get_sequence_parallel_world_size()}) parts evenly")
# patch sizes for the temporal, height, and width dimensions are 1, 2, and 2.
temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2
x = torch.chunk(x, get_sequence_parallel_world_size(),dim=split_dim)[get_sequence_parallel_rank()]
dim_thw = freqs_cos.shape[-1]
freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)
freqs_cos = torch.chunk(freqs_cos, get_sequence_parallel_world_size(),dim=split_dim - 1)[get_sequence_parallel_rank()]
freqs_cos = freqs_cos.reshape(-1, dim_thw)
dim_thw = freqs_sin.shape[-1]
freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
freqs_sin = torch.chunk(freqs_sin, get_sequence_parallel_world_size(),dim=split_dim - 1)[get_sequence_parallel_rank()]
freqs_sin = freqs_sin.reshape(-1, dim_thw)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
for block in transformer.double_blocks + transformer.single_blocks:
block.hybrid_seq_parallel_attn = xFuserLongContextAttention()
output = original_forward(
x,
t,
text_states,
text_mask,
text_states_2,
freqs_cos,
freqs_sin,
guidance,
return_dict,
)
return_dict = not isinstance(output, tuple)
sample = output["x"]
sample = get_sp_group().all_gather(sample, dim=split_dim)
output["x"] = sample
return output
new_forward = new_forward.__get__(transformer)
transformer.forward = new_forward
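# Minimal single-process sketch (ours, not part of the original file) of the sharding that
# `new_forward` above performs: the latents are chunked along height (or width) across
# sequence-parallel ranks and re-assembled afterwards; `torch.cat` stands in for
# `get_sp_group().all_gather` along the same dimension.
def _example_sequence_parallel_split(world_size: int = 2):
    x = torch.randn(1, 16, 33, 24, 42)  # (B, C, T, H, W) latents with hypothetical sizes
    if x.shape[-2] // 2 % world_size == 0:
        split_dim = -2                  # shard along height (in units of 2-pixel patches)
    elif x.shape[-1] // 2 % world_size == 0:
        split_dim = -1                  # shard along width
    else:
        raise ValueError("latent height/width not divisible by the sequence-parallel world size")
    shards = torch.chunk(x, world_size, dim=split_dim)  # one shard per rank
    gathered = torch.cat(shards, dim=split_dim)         # stand-in for all_gather on the SP group
    assert torch.equal(gathered, x)
    return shards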
class Inference(object):
def __init__(
self,
args,
vae,
vae_kwargs,
text_encoder,
model,
text_encoder_2=None,
pipeline=None,
use_cpu_offload=False,
device=None,
logger=None,
parallel_args=None,
):
self.vae = vae
self.vae_kwargs = vae_kwargs
self.text_encoder = text_encoder
self.text_encoder_2 = text_encoder_2
self.model = model
self.pipeline = pipeline
self.use_cpu_offload = use_cpu_offload
self.args = args
self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
self.logger = logger
self.parallel_args = parallel_args
@classmethod
def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
"""
Initialize the Inference pipeline.
Args:
pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
args (argparse.Namespace): The arguments for the pipeline.
device (str or torch.device, *optional*): The device for inference. If None, CUDA is used when available, otherwise CPU.
"""
# ========================================================================
logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
# ==================== Initialize Distributed Environment ================
if args.ulysses_degree > 1 or args.ring_degree > 1:
assert xfuser is not None, \
"Ulysses Attention and Ring Attention requires xfuser package."
assert args.use_cpu_offload is False, \
"Cannot enable use_cpu_offload in the distributed environment."
dist.init_process_group("nccl")
assert dist.get_world_size() == args.ring_degree * args.ulysses_degree, \
"number of GPUs should be equal to ring_degree * ulysses_degree."
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=args.ring_degree,
ulysses_degree=args.ulysses_degree,
)
device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
else:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}
# ======================== Get the args path =============================
# Disable gradient
torch.set_grad_enabled(False)
# =========================== Build main model ===========================
logger.info("Building model...")
factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
in_channels = args.latent_channels
out_channels = args.latent_channels
model = load_model(
args,
in_channels=in_channels,
out_channels=out_channels,
factor_kwargs=factor_kwargs,
)
if args.use_fp8:
convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
model = model.to(device)
model = Inference.load_state_dict(args, model, pretrained_model_path)
model.eval()
# ============================= Build extra models ========================
# VAE
vae, _, s_ratio, t_ratio = load_vae(
args.vae,
args.vae_precision,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
)
vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
# Text encoder
if args.prompt_template_video is not None:
crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get(
"crop_start", 0
)
elif args.prompt_template is not None:
crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
else:
crop_start = 0
max_length = args.text_len + crop_start
# prompt_template
prompt_template = (
PROMPT_TEMPLATE[args.prompt_template]
if args.prompt_template is not None
else None
)
# prompt_template_video
prompt_template_video = (
PROMPT_TEMPLATE[args.prompt_template_video]
if args.prompt_template_video is not None
else None
)
text_encoder = TextEncoder(
text_encoder_type=args.text_encoder,
max_length=max_length,
text_encoder_precision=args.text_encoder_precision,
tokenizer_type=args.tokenizer,
prompt_template=prompt_template,
prompt_template_video=prompt_template_video,
hidden_state_skip_layer=args.hidden_state_skip_layer,
apply_final_norm=args.apply_final_norm,
reproduce=args.reproduce,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
)
text_encoder_2 = None
if args.text_encoder_2 is not None:
text_encoder_2 = TextEncoder(
text_encoder_type=args.text_encoder_2,
max_length=args.text_len_2,
text_encoder_precision=args.text_encoder_precision_2,
tokenizer_type=args.tokenizer_2,
reproduce=args.reproduce,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
)
return cls(
args=args,
vae=vae,
vae_kwargs=vae_kwargs,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
model=model,
use_cpu_offload=args.use_cpu_offload,
device=device,
logger=logger,
parallel_args=parallel_args
)
@staticmethod
def load_state_dict(args, model, pretrained_model_path):
load_key = args.load_key
dit_weight = Path(args.dit_weight) if args.dit_weight is not None else None
if dit_weight is None:
model_dir = pretrained_model_path / f"t2v_{args.model_resolution}"
files = list(model_dir.glob("*.pt"))
if len(files) == 0:
raise ValueError(f"No model weights found in {model_dir}")
if files[0].name.startswith("pytorch_model_"):
model_path = model_dir / f"pytorch_model_{load_key}.pt"
bare_model = True
elif any(str(f).endswith("_model_states.pt") for f in files):
files = [f for f in files if str(f).endswith("_model_states.pt")]
model_path = files[0]
if len(files) > 1:
logger.warning(
f"Multiple model weights found in {dit_weight}, using {model_path}"
)
bare_model = False
else:
raise ValueError(
f"Invalid model path: {dit_weight} with unrecognized weight format: "
f"{list(map(str, files))}. When given a directory as --dit-weight, only "
f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
f"specific weight file, please provide the full path to the file."
)
else:
if dit_weight.is_dir():
files = list(dit_weight.glob("*.pt"))
if len(files) == 0:
raise ValueError(f"No model weights found in {dit_weight}")
if files[0].name.startswith("pytorch_model_"):
model_path = dit_weight / f"pytorch_model_{load_key}.pt"
bare_model = True
elif any(str(f).endswith("_model_states.pt") for f in files):
files = [f for f in files if str(f).endswith("_model_states.pt")]
model_path = files[0]
if len(files) > 1:
logger.warning(
f"Multiple model weights found in {dit_weight}, using {model_path}"
)
bare_model = False
else:
raise ValueError(
f"Invalid model path: {dit_weight} with unrecognized weight format: "
f"{list(map(str, files))}. When given a directory as --dit-weight, only "
f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
f"specific weight file, please provide the full path to the file."
)
elif dit_weight.is_file():
model_path = dit_weight
bare_model = "unknown"
else:
raise ValueError(f"Invalid model path: {dit_weight}")
if not model_path.exists():
raise ValueError(f"model_path not exists: {model_path}")
logger.info(f"Loading torch model {model_path}...")
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
bare_model = False
if bare_model is False:
if load_key in state_dict:
state_dict = state_dict[load_key]
else:
raise KeyError(
f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
f"are: {list(state_dict.keys())}."
)
model.load_state_dict(state_dict, strict=True)
return model
@staticmethod
def parse_size(size):
if isinstance(size, int):
size = [size]
if not isinstance(size, (list, tuple)):
raise ValueError(f"Size must be an integer or (height, width), got {size}.")
if len(size) == 1:
size = [size[0], size[0]]
if len(size) != 2:
raise ValueError(f"Size must be an integer or (height, width), got {size}.")
return size
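# Illustrative sketch (not part of the original file; the concrete sizes are assumptions):
# >>> Inference.parse_size(720)
# [720, 720]
# >>> Inference.parse_size([544, 960])
# [544, 960]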
class HunyuanVideoSampler(Inference):
def __init__(
self,
args,
vae,
vae_kwargs,
text_encoder,
model,
text_encoder_2=None,
pipeline=None,
use_cpu_offload=False,
device=0,
logger=None,
parallel_args=None
):
super().__init__(
args,
vae,
vae_kwargs,
text_encoder,
model,
text_encoder_2=text_encoder_2,
pipeline=pipeline,
use_cpu_offload=use_cpu_offload,
device=device,
logger=logger,
parallel_args=parallel_args
)
self.pipeline = self.load_diffusion_pipeline(
args=args,
vae=self.vae,
text_encoder=self.text_encoder,
text_encoder_2=self.text_encoder_2,
model=self.model,
device=self.device,
)
self.default_negative_prompt = NEGATIVE_PROMPT
if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
parallelize_transformer(self.pipeline)
def load_diffusion_pipeline(
self,
args,
vae,
text_encoder,
text_encoder_2,
model,
scheduler=None,
device=None,
progress_bar_config=None,
data_type="video",
):
"""Load the denoising scheduler for inference."""
if scheduler is None:
if args.denoise_type == "flow":
scheduler = FlowMatchDiscreteScheduler(
shift=args.flow_shift,
reverse=args.flow_reverse,
solver=args.flow_solver,
)
else:
raise ValueError(f"Invalid denoise type {args.denoise_type}")
pipeline = HunyuanVideoPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
transformer=model,
scheduler=scheduler,
progress_bar_config=progress_bar_config,
args=args,
)
if self.use_cpu_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to(device)
return pipeline
def get_rotary_pos_embed(self, video_length, height, width):
target_ndim = 3
ndim = 5 - 2
# 884
if "884" in self.args.vae:
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
elif "888" in self.args.vae:
latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
else:
latents_size = [video_length, height // 8, width // 8]
if isinstance(self.model.patch_size, int):
assert all(s % self.model.patch_size == 0 for s in latents_size), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [s // self.model.patch_size for s in latents_size]
elif isinstance(self.model.patch_size, list):
assert all(
s % self.model.patch_size[idx] == 0
for idx, s in enumerate(latents_size)
), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [
s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
]
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
head_dim = self.model.hidden_size // self.model.heads_num
rope_dim_list = self.model.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert (
sum(rope_dim_list) == head_dim
), "sum(rope_dim_list) should equal to head_dim of attention layer"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=self.args.rope_theta,
use_real=True,
theta_rescale_factor=1,
)
return freqs_cos, freqs_sin
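# Worked example (a sketch assuming the HYVideo-T/2 config and an "884" VAE):
# video_length=129, height=720, width=1280 gives
#   latents_size = [(129 - 1) // 4 + 1, 720 // 8, 1280 // 8] = [33, 90, 160]
# and, with patch_size = [1, 2, 2],
#   rope_sizes = [33, 45, 80].
# head_dim = 3072 // 24 = 128 and rope_dim_list = [16, 56, 56] sums to 128, so the
# t/h/w axes each receive their own slice of the rotary frequencies.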
@torch.no_grad()
def predict(
self,
prompt,
height=192,
width=336,
video_length=129,
seed=None,
negative_prompt=None,
infer_steps=50,
guidance_scale=6.0,
flow_shift=5.0,
embedded_guidance_scale=None,
batch_size=1,
num_videos_per_prompt=1,
**kwargs,
):
"""
Predict the image/video from the given text.
Args:
prompt (str): The input text prompt.
kwargs:
height (int): The height of the output video. Default is 192.
width (int): The width of the output video. Default is 336.
video_length (int): The frame number of the output video. Default is 129.
seed (int or List[int]): The random seed for the generation. Default is a random integer.
negative_prompt (str or List[str]): The negative text prompt. Defaults to the built-in negative prompt.
guidance_scale (float): The guidance scale for the generation. Default is 6.0.
num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
infer_steps (int): The number of inference steps. Default is 50.
"""
out_dict = dict()
# ========================================================================
# Arguments: seed
# ========================================================================
if isinstance(seed, torch.Tensor):
seed = seed.tolist()
if seed is None:
seeds = [
random.randint(0, 1_000_000)
for _ in range(batch_size * num_videos_per_prompt)
]
elif isinstance(seed, int):
seeds = [
seed + i
for _ in range(batch_size)
for i in range(num_videos_per_prompt)
]
elif isinstance(seed, (list, tuple)):
if len(seed) == batch_size:
seeds = [
int(seed[i]) + j
for i in range(batch_size)
for j in range(num_videos_per_prompt)
]
elif len(seed) == batch_size * num_videos_per_prompt:
seeds = [int(s) for s in seed]
else:
raise ValueError(
f"Length of seed must be equal to number of prompt(batch_size) or "
f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
)
else:
raise ValueError(
f"Seed must be an integer, a list of integers, or None, got {seed}."
)
generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
out_dict["seeds"] = seeds
# ========================================================================
# Arguments: target_width, target_height, target_video_length
# ========================================================================
if width <= 0 or height <= 0 or video_length <= 0:
raise ValueError(
f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
)
if (video_length - 1) % 4 != 0:
raise ValueError(
f"`video_length-1` must be a multiple of 4, got {video_length}"
)
logger.info(
f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
)
target_height = align_to(height, 16)
target_width = align_to(width, 16)
target_video_length = video_length
out_dict["size"] = (target_height, target_width, target_video_length)
# ========================================================================
# Arguments: prompt, new_prompt, negative_prompt
# ========================================================================
if not isinstance(prompt, str):
raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
prompt = [prompt.strip()]
# negative prompt
if negative_prompt is None or negative_prompt == "":
negative_prompt = self.default_negative_prompt
if guidance_scale == 1.0:
negative_prompt = ""
if not isinstance(negative_prompt, str):
raise TypeError(
f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
)
negative_prompt = [negative_prompt.strip()]
# ========================================================================
# Scheduler
# ========================================================================
scheduler = FlowMatchDiscreteScheduler(
shift=flow_shift,
reverse=self.args.flow_reverse,
solver=self.args.flow_solver
)
self.pipeline.scheduler = scheduler
# ========================================================================
# Build Rope freqs
# ========================================================================
freqs_cos, freqs_sin = self.get_rotary_pos_embed(
target_video_length, target_height, target_width
)
n_tokens = freqs_cos.shape[0]
# ========================================================================
# Print infer args
# ========================================================================
debug_str = f"""
height: {target_height}
width: {target_width}
video_length: {target_video_length}
prompt: {prompt}
neg_prompt: {negative_prompt}
seed: {seed}
infer_steps: {infer_steps}
num_videos_per_prompt: {num_videos_per_prompt}
guidance_scale: {guidance_scale}
n_tokens: {n_tokens}
flow_shift: {flow_shift}
embedded_guidance_scale: {embedded_guidance_scale}"""
logger.debug(debug_str)
# ========================================================================
# Pipeline inference
# ========================================================================
start_time = time.time()
samples = self.pipeline(
prompt=prompt,
height=target_height,
width=target_width,
video_length=target_video_length,
num_inference_steps=infer_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
generator=generator,
output_type="pil",
freqs_cis=(freqs_cos, freqs_sin),
n_tokens=n_tokens,
embedded_guidance_scale=embedded_guidance_scale,
data_type="video" if target_video_length > 1 else "image",
is_progress_bar=True,
vae_ver=self.args.vae,
enable_tiling=self.args.vae_tiling,
)[0]
out_dict["samples"] = samples
out_dict["prompts"] = prompt
gen_time = time.time() - start_time
logger.info(f"Success, time: {gen_time}")
return out_dict
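# Minimal usage sketch (illustrative only; the factory method name `from_pretrained`
# and all argument values are assumptions, not taken from this file):
# sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
# outputs = sampler.predict(
#     prompt="a cat walking on wet grass",
#     height=720,
#     width=1280,
#     video_length=129,
#     seed=42,
#     infer_steps=50,
#     guidance_scale=6.0,
# )
# videos = outputs["samples"]  # decoded samples; the seeds used are in outputs["seeds"]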
from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
def load_model(args, in_channels, out_channels, factor_kwargs):
"""load hunyuan video model
Args:
args (dict): model args
in_channels (int): input channels number
out_channels (int): output channels number
factor_kwargs (dict): factor kwargs
Returns:
model (nn.Module): The hunyuan video model
"""
if args.model in HUNYUAN_VIDEO_CONFIG.keys():
model = HYVideoDiffusionTransformer(
args,
in_channels=in_channels,
out_channels=out_channels,
**HUNYUAN_VIDEO_CONFIG[args.model],
**factor_kwargs,
)
return model
else:
raise NotImplementedError()
import torch.nn as nn
def get_activation_layer(act_type):
"""get activation layer
Args:
act_type (str): the activation type
Returns:
torch.nn.functional: the activation layer
"""
if act_type == "gelu":
return lambda: nn.GELU()
elif act_type == "gelu_tanh":
# Approximate `tanh` requires torch >= 1.13
return lambda: nn.GELU(approximate="tanh")
elif act_type == "relu":
return nn.ReLU
elif act_type == "silu":
return nn.SiLU
else:
raise ValueError(f"Unknown activation type: {act_type}")
import importlib.metadata
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
flash_attn = None
flash_attn_varlen_func = None
_flash_attn_forward = None
MEMORY_LAYOUT = {
"flash": (
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
lambda x: x,
),
"torch": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
"vanilla": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
}
def get_cu_seqlens(text_mask, img_len):
"""Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
Args:
text_mask (torch.Tensor): the text attention mask of shape [batch_size, text_len]
img_len (int): the length of the image/video token sequence
Returns:
torch.Tensor: the calculated cu_seqlens for flash attention
"""
batch_size = text_mask.shape[0]
text_len = text_mask.sum(dim=1)
max_len = text_mask.shape[1] + img_len
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
for i in range(batch_size):
s = text_len[i] + img_len
s1 = i * max_len + s
s2 = (i + 1) * max_len
cu_seqlens[2 * i + 1] = s1
cu_seqlens[2 * i + 2] = s2
return cu_seqlens
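# Worked example (a sketch with assumed values): for a single sample with
# text_mask = [[1, 1, 0]] and img_len = 4 we get text_len = [2], max_len = 7 and
#   cu_seqlens = [0, 6, 7]
# i.e. positions [0, 6) are the valid image+text tokens and [6, 7) is the padded
# tail, which is the cu_seqlens layout flash_attn_varlen_func expects.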
def attention(
q,
k,
v,
mode="flash",
drop_rate=0,
attn_mask=None,
causal=False,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
batch_size=1,
):
"""
Perform QKV self attention.
Args:
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
drop_rate (float): Dropout rate in attention map. (default: 0)
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
(default: None)
causal (bool): Whether to use causal attention. (default: False)
cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
used to index into q.
cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
used to index into kv.
max_seqlen_q (int): The maximum sequence length in the batch of q.
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
Returns:
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
"""
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
q = pre_attn_layout(q)
k = pre_attn_layout(k)
v = pre_attn_layout(v)
if mode == "torch":
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
if cu_seqlens_q is None:
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
)
else:
attn1 = F.scaled_dot_product_attention(
q[:, :, :cu_seqlens_q[1]],
k[:, :, :cu_seqlens_kv[1]],
v[:, :, :cu_seqlens_kv[1]],
attn_mask=attn_mask,
dropout_p=drop_rate,
is_causal=causal
)
attn2 = F.scaled_dot_product_attention(
q[:, :, cu_seqlens_q[1]:],
k[:, :, cu_seqlens_kv[1]:],
v[:, :, cu_seqlens_kv[1]:],
attn_mask=None,
dropout_p=drop_rate,
is_causal=False
)
x = torch.cat([attn1, attn2], dim=2)
elif mode == "flash":
x = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
)
# x with shape [(bxs), a, d]
x = x.view(
batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
) # reshape x to [b, s, a, d]
elif mode == "vanilla":
scale_factor = 1 / math.sqrt(q.size(-1))
b, a, s, _ = q.shape
s1 = k.size(2)
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
if causal:
# Only applied to self attention
assert (
attn_mask is None
), "Causal mask and attn_mask cannot be used together"
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
diagonal=0
)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(q.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
# TODO: Maybe force q and k to be float32 to avoid numerical overflow
attn = (q @ k.transpose(-2, -1)) * scale_factor
attn += attn_bias
attn = attn.softmax(dim=-1)
attn = torch.dropout(attn, p=drop_rate, train=True)
x = attn @ v
else:
raise NotImplementedError(f"Unsupported attention mode: {mode}")
x = post_attn_layout(x)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
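# Shape sketch for the "torch" path (sizes are assumptions, not from the original
# code): q, k, v of shape [B, S, H, D] are transposed to [B, H, S, D] for
# F.scaled_dot_product_attention and transposed back afterwards, so with
# B=2, S=16, H=24, D=128 the returned tensor has shape [2, 16, 24 * 128] = [2, 16, 3072].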
def parallel_attention(
hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len,
img_kv_len,
cu_seqlens_q,
cu_seqlens_kv
):
attn1 = hybrid_seq_parallel_attn(
None,
q[:, :img_q_len, :, :],
k[:, :img_kv_len, :, :],
v[:, :img_kv_len, :, :],
dropout_p=0.0,
causal=False,
joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
joint_strategy="rear",
)
flash_attn_version = tuple(int(p) for p in flash_attn.__version__.split(".")[:2])
if flash_attn_version >= (2, 7):
attn2, *_ = _flash_attn_forward(
q[:,cu_seqlens_q[1]:],
k[:,cu_seqlens_kv[1]:],
v[:,cu_seqlens_kv[1]:],
dropout_p=0.0,
softmax_scale=q.shape[-1] ** (-0.5),
causal=False,
window_size_left=-1,
window_size_right=-1,
softcap=0.0,
alibi_slopes=None,
return_softmax=False,
)
else:
attn2, *_ = _flash_attn_forward(
q[:,cu_seqlens_q[1]:],
k[:,cu_seqlens_kv[1]:],
v[:,cu_seqlens_kv[1]:],
dropout_p=0.0,
softmax_scale=q.shape[-1] ** (-0.5),
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
return_softmax=False,
)
attn = torch.cat([attn1, attn2], dim=1)
b, s, a, d = attn.shape
attn = attn.reshape(b, s, -1)
return attn
import math
import torch
import torch.nn as nn
from einops import rearrange, repeat
from ..utils.helpers import to_2tuple
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding
Image to Patch Embedding using Conv2d
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman
Remove the _assert function in forward function to be compatible with multi-resolution images.
"""
def __init__(
self,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.flatten = flatten
self.proj = nn.Conv3d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
**factory_kwargs
)
nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
if bias:
nn.init.zeros_(self.proj.bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
x = self.norm(x)
return x
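# Shape sketch (illustrative, assuming patch_size = (1, 2, 2) and embed_dim = 3072):
# a latent of shape [B, C, T, H, W] = [1, 16, 33, 90, 160] is projected by the Conv3d
# to [1, 3072, 33, 45, 80] and then flattened to tokens of shape [1, 33 * 45 * 80, 3072].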
class TextProjection(nn.Module):
"""
Projects text embeddings with a two-layer MLP.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.linear_1 = nn.Linear(
in_features=in_channels,
out_features=hidden_size,
bias=True,
**factory_kwargs
)
self.act_1 = act_layer()
self.linear_2 = nn.Linear(
in_features=hidden_size,
out_features=hidden_size,
bias=True,
**factory_kwargs
)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
dim (int): the dimension of the output.
max_period (int): controls the minimum frequency of the embeddings.
Returns:
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
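# Sketch (assumed values): timestep_embedding(torch.tensor([0.0, 500.0]), dim=256)
# returns a [2, 256] tensor whose first 128 columns are cosines and last 128 are
# sines of the timestep scaled by log-spaced frequencies; t = 0 therefore maps to
# 128 ones followed by 128 zeros.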
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(
self,
hidden_size,
act_layer,
frequency_embedding_size=256,
max_period=10000,
out_size=None,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
if out_size is None:
out_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(
frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
),
act_layer(),
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
)
nn.init.normal_(self.mlp[0].weight, std=0.02)
nn.init.normal_(self.mlp[2].weight, std=0.02)
def forward(self, t):
t_freq = timestep_embedding(
t, self.frequency_embedding_size, self.max_period
).type(self.mlp[0].weight.dtype)
t_emb = self.mlp(t_freq)
return t_emb
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
_bits = torch.tensor(bits)
_mantissa_bit = torch.tensor(mantissa_bit)
_sign_bits = torch.tensor(sign_bits)
M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
E = _bits - _sign_bits - M
bias = 2 ** (E - 1) - 1
mantissa = 1
for i in range(mantissa_bit - 1):
mantissa += 1 / (2 ** (i+1))
maxval = mantissa * 2 ** (2**E - 1 - bias)
return maxval
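# Worked example: with the defaults (bits=8, mantissa_bit=3, sign_bits=1, i.e. E4M3)
# M = 3, E = 4, bias = 7 and mantissa = 1 + 1/2 + 1/4 = 1.75, so
# maxval = 1.75 * 2 ** (2 ** 4 - 1 - 7) = 1.75 * 256 = 448,
# matching the representable maximum of torch.float8_e4m3fn.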
def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
"""
Default is E4M3.
"""
bits = torch.tensor(bits)
mantissa_bit = torch.tensor(mantissa_bit)
sign_bits = torch.tensor(sign_bits)
M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
E = bits - sign_bits - M
bias = 2 ** (E - 1) - 1
mantissa = 1
for i in range(mantissa_bit - 1):
mantissa += 1 / (2 ** (i+1))
maxval = mantissa * 2 ** (2**E - 1 - bias)
minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
input_clamp = torch.min(torch.max(x, minval), maxval)
log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
# dequant
qdq_out = torch.round(input_clamp / log_scales) * log_scales
return qdq_out, log_scales
def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
for i in range(len(x.shape) - 1):
scale = scale.unsqueeze(-1)
new_x = x / scale
quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
return quant_dequant_x, scale, log_scales
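# Sketch of the round trip (tensor values and the scaling convention are assumptions):
# fp8_tensor_quant broadcasts `scale` over all but the first dimension, fake-quantizes
# x / scale onto the E4M3 grid via quantize_to_fp8, and returns the result together
# with the broadcast scale, e.g.
#   w = torch.randn(3072, 3072)
#   scale = w.abs().amax(dim=1) / 448.0           # one scale per output channel
#   w_qdq, scale, _ = fp8_tensor_quant(w, scale)  # w_qdq approximates w / scale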
def fp8_activation_dequant(qdq_out, scale, dtype):
qdq_out = qdq_out.type(dtype)
quant_dequant_x = qdq_out * scale.to(dtype)
return quant_dequant_x
def fp8_linear_forward(cls, original_dtype, input):
weight_dtype = cls.weight.dtype
#####
if cls.weight.dtype != torch.float8_e4m3fn:
maxval = get_fp_maxval()
scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
linear_weight = linear_weight.to(torch.float8_e4m3fn)
weight_dtype = linear_weight.dtype
else:
scale = cls.fp8_scale.to(cls.weight.device)
linear_weight = cls.weight
#####
if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
if True or len(input.shape) == 3:
cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
if cls.bias is not None:
output = F.linear(input, cls_dequant, cls.bias)
else:
output = F.linear(input, cls_dequant)
return output
else:
return cls.original_forward(input.to(original_dtype))
else:
return cls.original_forward(input)
def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
setattr(module, "fp8_matmul_enabled", True)
# loading fp8 mapping file
fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
if os.path.exists(fp8_map_path):
fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
else:
raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
fp8_layers = []
for key, layer in module.named_modules():
if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
fp8_layers.append(key)
original_forward = layer.forward
layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
setattr(layer, "original_forward", original_forward)
setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
# Modified from timm library:
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
from functools import partial
import torch
import torch.nn as nn
from .modulate_layers import modulate
from ..utils.helpers import to_2tuple
class MLP(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_channels,
hidden_channels=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
out_features = out_features or in_channels
hidden_channels = hidden_channels or in_channels
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(
in_channels, hidden_channels, bias=bias[0], **factory_kwargs
)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = (
norm_layer(hidden_channels, **factory_kwargs)
if norm_layer is not None
else nn.Identity()
)
self.fc2 = linear_layer(
hidden_channels, out_features, bias=bias[1], **factory_kwargs
)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
#
class MLPEmbedder(nn.Module):
"""copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class FinalLayer(nn.Module):
"""The final layer of DiT."""
def __init__(
self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
# Just use LayerNorm for the final layer
self.norm_final = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
if isinstance(patch_size, int):
self.linear = nn.Linear(
hidden_size,
patch_size * patch_size * out_channels,
bias=True,
**factory_kwargs
)
else:
self.linear = nn.Linear(
hidden_size,
patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
bias=True,
)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
# Here we don't distinguish between the modulate types. Just use the simple one.
self.adaLN_modulation = nn.Sequential(
act_layer(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
)
# Zero-initialize the modulation
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift=shift, scale=scale)
x = self.linear(x)
return x
from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attenion import attention, parallel_attention, get_cu_seqlens
from .posemb_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, apply_gate
from .token_refiner import SingleTokenRefiner
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal DiT block with separate modulation for
text and image/video. See more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qkv_bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.img_attn_q_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_k_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.img_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.txt_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
self.txt_attn_q_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_k_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.txt_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: tuple = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.img_mod(vec).chunk(6, dim=-1)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.txt_mod(vec).chunk(6, dim=-1)
# Prepare image for attention.
img_modulated = self.img_norm1(img)
img_modulated = modulate(
img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
)
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# Apply QK-Norm if needed
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
# Apply RoPE if needed.
if freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
# Prepare txt for attention.
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(
txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# Apply QK-Norm if needed.
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
# Run actual attention.
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
assert (
cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
# attention computation start
if not self.hybrid_seq_parallel_attn:
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=img_k.shape[0],
)
else:
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# attention computation end
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
# Calculate the img blocks.
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(
self.img_mlp(
modulate(
self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
)
),
gate=img_mod2_gate,
)
# Calculate the txt blocks.
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(
self.txt_mlp(
modulate(
self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
)
),
gate=txt_mod2_gate,
)
return img, txt
class MMSingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
Also refer to (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
self.scale = qk_scale or head_dim ** -0.5
# qkv and mlp_in
self.linear1 = nn.Linear(
hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
)
# proj and mlp_out
self.linear2 = nn.Linear(
hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.pre_norm = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.mlp_act = get_activation_layer(mlp_act_type)()
self.modulation = ModulateDiT(
hidden_size,
factor=3,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
qkv, mlp = torch.split(
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
# Apply RoPE if needed.
if freqs_cis is not None:
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
# Compute attention.
assert (
cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
# attention computation start
if not self.hybrid_seq_parallel_attn:
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=x.shape[0],
)
else:
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# attention computation end
# Compute activation in mlp stream, cat again and run second linear layer.
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + apply_gate(output, gate=mod_gate)
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
"""
HunyuanVideo Transformer backbone
Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
Reference:
[1] Flux.1: https://github.com/black-forest-labs/flux
[2] MMDiT: http://arxiv.org/abs/2403.03206
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
patch_size: list
The size of the patch.
in_channels: int
The number of input channels.
out_channels: int
The number of output channels.
hidden_size: int
The hidden size of the transformer backbone.
heads_num: int
The number of attention heads.
mlp_width_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
mlp_act_type: str
The activation function of the MLP in the transformer block.
mm_double_blocks_depth: int
The number of transformer blocks in the double-stream stage.
mm_single_blocks_depth: int
The number of transformer blocks in the single-stream stage.
rope_dim_list: list
The dimension of the rotary embedding for t, h, w.
qkv_bias: bool
Whether to use bias in the qkv linear layer.
qk_norm: bool
Whether to use qk norm.
qk_norm_type: str
The type of qk norm.
guidance_embed: bool
Whether to use guidance embedding for distillation.
text_projection: str
The type of the text projection, default is single_refiner.
use_attention_mask: bool
Whether to use attention mask for text encoder.
dtype: torch.dtype
The dtype of the model.
device: torch.device
The device of the model.
"""
@register_to_config
def __init__(
self,
args: Any,
patch_size: list = [1, 2, 2],
in_channels: int = 4, # Should be VAE.config.latent_channels.
out_channels: int = None,
hidden_size: int = 3072,
heads_num: int = 24,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
mm_double_blocks_depth: int = 20,
mm_single_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
qkv_bias: bool = True,
qk_norm: bool = True,
qk_norm_type: str = "rms",
guidance_embed: bool = False, # For modulation.
text_projection: str = "single_refiner",
use_attention_mask: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.unpatchify_channels = self.out_channels
self.guidance_embed = guidance_embed
self.rope_dim_list = rope_dim_list
# Text projection. Default to linear projection.
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
self.use_attention_mask = use_attention_mask
self.text_projection = text_projection
self.text_states_dim = args.text_states_dim
self.text_states_dim_2 = args.text_states_dim_2
if hidden_size % heads_num != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
)
pe_dim = hidden_size // heads_num
if sum(rope_dim_list) != pe_dim:
raise ValueError(
f"Got {rope_dim_list} but expected positional dim {pe_dim}"
)
self.hidden_size = hidden_size
self.heads_num = heads_num
# image projection
self.img_in = PatchEmbed(
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
)
# text projection
if self.text_projection == "linear":
self.txt_in = TextProjection(
self.text_states_dim,
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs,
)
elif self.text_projection == "single_refiner":
self.txt_in = SingleTokenRefiner(
self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
# time modulation
self.time_in = TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
# text modulation
self.vector_in = MLPEmbedder(
self.text_states_dim_2, self.hidden_size, **factory_kwargs
)
# guidance modulation
self.guidance_in = (
TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
if guidance_embed
else None
)
# double blocks
self.double_blocks = nn.ModuleList(
[
MMDoubleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
]
)
# single blocks
self.single_blocks = nn.ModuleList(
[
MMSingleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(mm_single_blocks_depth)
]
)
self.final_layer = FinalLayer(
self.hidden_size,
self.patch_size,
self.out_channels,
get_activation_layer("silu"),
**factory_kwargs,
)
def enable_deterministic(self):
for block in self.double_blocks:
block.enable_deterministic()
for block in self.single_blocks:
block.enable_deterministic()
def disable_deterministic(self):
for block in self.double_blocks:
block.disable_deterministic()
for block in self.single_blocks:
block.disable_deterministic()
def forward(
self,
x: torch.Tensor,
t: torch.Tensor, # Should be in range(0, 1000).
text_states: torch.Tensor = None,
text_mask: torch.Tensor = None, # Now we don't use it.
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
out = {}
img = x
txt = text_states
_, _, ot, oh, ow = x.shape
tt, th, tw = (
ot // self.patch_size[0],
oh // self.patch_size[1],
ow // self.patch_size[2],
)
# Prepare modulation vectors.
vec = self.time_in(t)
# text modulation
vec = vec + self.vector_in(text_states_2)
# guidance modulation
if self.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
vec = vec + self.guidance_in(guidance)
# Embed image and text.
img = self.img_in(img)
if self.text_projection == "linear":
txt = self.txt_in(txt)
elif self.text_projection == "single_refiner":
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
# Compute cu_seqlens and max_seqlen for flash attention
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
cu_seqlens_kv = cu_seqlens_q
max_seqlen_q = img_seq_len + txt_seq_len
max_seqlen_kv = max_seqlen_q
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# --------------------- Pass through DiT blocks ------------------------
for _, block in enumerate(self.double_blocks):
double_block_args = [
img,
txt,
vec,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
freqs_cis,
]
img, txt = block(*double_block_args)
# Merge txt and img to pass through single stream blocks.
x = torch.cat((img, txt), 1)
if len(self.single_blocks) > 0:
for _, block in enumerate(self.single_blocks):
single_block_args = [
x,
vec,
txt_seq_len,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
(freqs_cos, freqs_sin),
]
x = block(*single_block_args)
img = x[:, :img_seq_len, ...]
# ---------------------------- Final layer ------------------------------
img = self.final_layer(img, vec) # (N, T, patch_size[0] * patch_size[1] * patch_size[2] * out_channels)
img = self.unpatchify(img, tt, th, tw)
if return_dict:
out["x"] = img
return out
return img
def unpatchify(self, x, t, h, w):
"""
x: (N, T, patch_size[0] * patch_size[1] * patch_size[2] * C)
imgs: (N, C, T * pt, H * ph, W * pw)
"""
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
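# Shape sketch (assumed sizes): with patch_size = [1, 2, 2], c = 16 and a token grid
# t, h, w = 33, 45, 80, the input x of shape [N, 33 * 45 * 80, 1 * 2 * 2 * 16] is
# rearranged back into latents of shape [N, 16, 33, 90, 160].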
def params_count(self):
counts = {
"double": sum(
[
sum(p.numel() for p in block.img_attn_qkv.parameters())
+ sum(p.numel() for p in block.img_attn_proj.parameters())
+ sum(p.numel() for p in block.img_mlp.parameters())
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
+ sum(p.numel() for p in block.txt_mlp.parameters())
for block in self.double_blocks
]
),
"single": sum(
[
sum(p.numel() for p in block.linear1.parameters())
+ sum(p.numel() for p in block.linear2.parameters())
for block in self.single_blocks
]
),
"total": sum(p.numel() for p in self.parameters()),
}
counts["attn+mlp"] = counts["double"] + counts["single"]
return counts
#################################################################################
# HunyuanVideo Configs #
#################################################################################
HUNYUAN_VIDEO_CONFIG = {
"HYVideo-T/2": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
},
"HYVideo-T/2-cfgdistill": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
"guidance_embed": True,
},
}
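# Minimal usage sketch (argument values are assumptions): selecting a config by name
# and instantiating the backbone through the load_model helper shown earlier.
#   factor_kwargs = {"device": "cuda", "dtype": torch.bfloat16}
#   args.model = "HYVideo-T/2-cfgdistill"
#   model = load_model(args, in_channels=16, out_channels=16, factor_kwargs=factor_kwargs)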
from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attenion import attention, parallel_attention, get_cu_seqlens
from .posemb_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, apply_gate
from .token_refiner import SingleTokenRefiner
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal DiT block with separate modulation for
text and image/video. See more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qkv_bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.img_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.img_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.txt_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
self.txt_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.txt_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: tuple = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.img_mod(vec).chunk(6, dim=-1)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.txt_mod(vec).chunk(6, dim=-1)
# Prepare image for attention.
img_modulated = self.img_norm1(img)
img_modulated = modulate(
img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
)
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# Apply QK-Norm if needed
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
# Apply RoPE if needed.
if freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
# Prepare txt for attention.
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(
txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# Apply QK-Norm if needed.
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
# Run actual attention.
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
assert (
cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
# attention computation start
if not self.hybrid_seq_parallel_attn:
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=img_k.shape[0],
)
else:
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# attention computation end
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
# Calculate the img blocks.
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(
self.img_mlp(
modulate(
self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
)
),
gate=img_mod2_gate,
)
# Calculate the txt blocks.
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(
self.txt_mlp(
modulate(
self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
)
),
gate=txt_mod2_gate,
)
return img, txt
class MMSingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
Also refer to (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
self.scale = qk_scale or head_dim ** -0.5
# qkv and mlp_in
self.linear1 = nn.Linear(
hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
)
# proj and mlp_out
self.linear2 = nn.Linear(
hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.pre_norm = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.mlp_act = get_activation_layer(mlp_act_type)()
self.modulation = ModulateDiT(
hidden_size,
factor=3,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
qkv, mlp = torch.split(
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
# Apply RoPE if needed.
if freqs_cis is not None:
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
# Compute attention.
assert (
cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
# attention computation start
if not self.hybrid_seq_parallel_attn:
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=x.shape[0],
)
else:
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# attention computation end
# Compute activation in mlp stream, cat again and run second linear layer.
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + apply_gate(output, gate=mod_gate)
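# --- Illustrative sketch (not part of the original file). A minimal toy
# version of the "parallel linear layers" design used by MMSingleStreamBlock
# above: one fused linear produces QKV plus the MLP input, and a second fused
# linear merges the attention output with the MLP output. It uses torch's
# built-in scaled_dot_product_attention instead of the repo's varlen attention
# wrapper, so all names and shapes here are illustrative assumptions, not the
# repo's API.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class ToySingleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int = 64, heads_num: int = 4, mlp_ratio: float = 4.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in fused into one projection
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # attention proj and mlp_out fused into one projection
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv, mlp = torch.split(
            self.linear1(self.pre_norm(x)),
            [3 * self.hidden_size, self.mlp_hidden_dim],
            dim=-1,
        )
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.heads_num)
        attn = F.scaled_dot_product_attention(q, k, v)  # (B, H, L, D)
        attn = rearrange(attn, "B H L D -> B L (H D)")
        # concatenate the attention and activated MLP streams, project back, add residual
        return x + self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=-1))


if __name__ == "__main__":
    block = ToySingleStreamBlock()
    print(block(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])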
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
"""
HunyuanVideo Transformer backbone
    Inherits from ModelMixin and ConfigMixin for compatibility with diffusers pipelines and samplers.
Reference:
[1] Flux.1: https://github.com/black-forest-labs/flux
[2] MMDiT: http://arxiv.org/abs/2403.03206
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
patch_size: list
The size of the patch.
in_channels: int
The number of input channels.
out_channels: int
The number of output channels.
hidden_size: int
The hidden size of the transformer backbone.
heads_num: int
The number of attention heads.
mlp_width_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
mlp_act_type: str
The activation function of the MLP in the transformer block.
    mm_double_blocks_depth: int
        The number of double-stream transformer blocks.
    mm_single_blocks_depth: int
        The number of single-stream transformer blocks.
rope_dim_list: list
The dimension of the rotary embedding for t, h, w.
qkv_bias: bool
Whether to use bias in the qkv linear layer.
qk_norm: bool
Whether to use qk norm.
qk_norm_type: str
The type of qk norm.
guidance_embed: bool
Whether to use guidance embedding for distillation.
text_projection: str
The type of the text projection, default is single_refiner.
use_attention_mask: bool
Whether to use attention mask for text encoder.
dtype: torch.dtype
The dtype of the model.
device: torch.device
The device of the model.
"""
@register_to_config
def __init__(
self,
args: Any,
patch_size: list = [1, 2, 2],
in_channels: int = 4, # Should be VAE.config.latent_channels.
out_channels: int = None,
hidden_size: int = 3072,
heads_num: int = 24,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
mm_double_blocks_depth: int = 20,
mm_single_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
qkv_bias: bool = True,
qk_norm: bool = True,
qk_norm_type: str = "rms",
guidance_embed: bool = False, # For modulation.
text_projection: str = "single_refiner",
use_attention_mask: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.unpatchify_channels = self.out_channels
self.guidance_embed = guidance_embed
self.rope_dim_list = rope_dim_list
        # Text projection: either a plain linear layer ("linear") or a
        # TokenRefiner ("single_refiner", the default). See more details (LI-DiT): http://arxiv.org/abs/2406.11831
self.use_attention_mask = use_attention_mask
self.text_projection = text_projection
self.text_states_dim = args.text_states_dim
self.text_states_dim_2 = args.text_states_dim_2
if hidden_size % heads_num != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
)
pe_dim = hidden_size // heads_num
if sum(rope_dim_list) != pe_dim:
raise ValueError(
f"Got {rope_dim_list} but expected positional dim {pe_dim}"
)
self.hidden_size = hidden_size
self.heads_num = heads_num
# image projection
self.img_in = PatchEmbed(
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
)
# text projection
if self.text_projection == "linear":
self.txt_in = TextProjection(
self.text_states_dim,
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs,
)
elif self.text_projection == "single_refiner":
self.txt_in = SingleTokenRefiner(
self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
# time modulation
self.time_in = TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
# text modulation
self.vector_in = MLPEmbedder(
self.text_states_dim_2, self.hidden_size, **factory_kwargs
)
# guidance modulation
self.guidance_in = (
TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
if guidance_embed
else None
)
# double blocks
self.double_blocks = nn.ModuleList(
[
MMDoubleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
]
)
# single blocks
self.single_blocks = nn.ModuleList(
[
MMSingleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(mm_single_blocks_depth)
]
)
self.final_layer = FinalLayer(
self.hidden_size,
self.patch_size,
self.out_channels,
get_activation_layer("silu"),
**factory_kwargs,
)
def enable_deterministic(self):
for block in self.double_blocks:
block.enable_deterministic()
for block in self.single_blocks:
block.enable_deterministic()
def disable_deterministic(self):
for block in self.double_blocks:
block.disable_deterministic()
for block in self.single_blocks:
block.disable_deterministic()
def forward(
self,
x: torch.Tensor,
t: torch.Tensor, # Should be in range(0, 1000).
text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,  # Used to build cu_seqlens and, when use_attention_mask is set, by the token refiner.
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
out = {}
img = x
txt = text_states
_, _, ot, oh, ow = x.shape
tt, th, tw = (
ot // self.patch_size[0],
oh // self.patch_size[1],
ow // self.patch_size[2],
)
# Prepare modulation vectors.
vec = self.time_in(t)
# text modulation
vec = vec + self.vector_in(text_states_2)
# guidance modulation
if self.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
            # The timestep-style embedding for guidance is computed inside guidance_in (a TimestepEmbedder).
vec = vec + self.guidance_in(guidance)
# Embed image and text.
img = self.img_in(img)
if self.text_projection == "linear":
txt = self.txt_in(txt)
elif self.text_projection == "single_refiner":
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
        # Compute cu_seqlens and max_seqlen for flash attention
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
cu_seqlens_kv = cu_seqlens_q
max_seqlen_q = img_seq_len + txt_seq_len
max_seqlen_kv = max_seqlen_q
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# --------------------- Pass through DiT blocks ------------------------
for _, block in enumerate(self.double_blocks):
double_block_args = [
img,
txt,
vec,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
freqs_cis,
]
img, txt = block(*double_block_args)
# Merge txt and img to pass through single stream blocks.
x = torch.cat((img, txt), 1)
if len(self.single_blocks) > 0:
for _, block in enumerate(self.single_blocks):
single_block_args = [
x,
vec,
txt_seq_len,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
(freqs_cos, freqs_sin),
]
x = block(*single_block_args)
img = x[:, :img_seq_len, ...]
# ---------------------------- Final layer ------------------------------
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
img = self.unpatchify(img, tt, th, tw)
if return_dict:
out["x"] = img
return out
return img
def unpatchify(self, x, t, h, w):
"""
        x: (N, L, prod(patch_size) * C)
        imgs: (N, C, T, H, W)
"""
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def params_count(self):
counts = {
"double": sum(
[
sum(p.numel() for p in block.img_attn_qkv.parameters())
+ sum(p.numel() for p in block.img_attn_proj.parameters())
+ sum(p.numel() for p in block.img_mlp.parameters())
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
+ sum(p.numel() for p in block.txt_mlp.parameters())
for block in self.double_blocks
]
),
"single": sum(
[
sum(p.numel() for p in block.linear1.parameters())
+ sum(p.numel() for p in block.linear2.parameters())
for block in self.single_blocks
]
),
"total": sum(p.numel() for p in self.parameters()),
}
counts["attn+mlp"] = counts["double"] + counts["single"]
return counts
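# --- Illustrative sketch (not part of the original file). A standalone
# restatement of HYVideoDiffusionTransformer.unpatchify above with the channel
# count and patch size passed explicitly; the helper name and its arguments
# are hypothetical, for illustration only. It shows how the
# (N, L, prod(patch_size) * C) token output is folded back into a
# (N, C, T, H, W) latent video.
import torch


def unpatchify_demo(x, t, h, w, c, patch_size):
    pt, ph, pw = patch_size
    assert t * h * w == x.shape[1]
    x = x.reshape(x.shape[0], t, h, w, c, pt, ph, pw)
    x = torch.einsum("nthwcopq->nctohpwq", x)
    return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw)


if __name__ == "__main__":
    # a 2 x 4 x 4 grid of patches, patch_size (1, 2, 2), 16 output channels
    tokens = torch.randn(1, 2 * 4 * 4, 1 * 2 * 2 * 16)
    video = unpatchify_demo(tokens, t=2, h=4, w=4, c=16, patch_size=(1, 2, 2))
    print(video.shape)  # torch.Size([1, 16, 2, 8, 8])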
#################################################################################
# HunyuanVideo Configs #
#################################################################################
HUNYUAN_VIDEO_CONFIG = {
"HYVideo-T/2": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
},
"HYVideo-T/2-cfgdistill": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
"guidance_embed": True,
},
}
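# --- Illustrative sketch (not part of the original file). How
# HUNYUAN_VIDEO_CONFIG above is typically consumed: pick a named config and
# merge it with per-run overrides before calling
# HYVideoDiffusionTransformer(args, **kwargs). The override values below are
# placeholders for illustration, not settings prescribed by the repo.
if __name__ == "__main__":
    model_name = "HYVideo-T/2-cfgdistill"
    kwargs = {
        **HUNYUAN_VIDEO_CONFIG[model_name],
        "in_channels": 16,  # placeholder: should match VAE.config.latent_channels
        "text_projection": "single_refiner",
    }
    print(model_name, "->", kwargs)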
from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attenion import attention, parallel_attention, get_cu_seqlens
from .posemb_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, apply_gate
from .token_refiner import SingleTokenRefiner
class MMDoubleStreamBlock(nn.Module):
"""
    A multimodal DiT block with separate modulation for
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qkv_bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.img_attn_q_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_k_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.img_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.txt_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
self.txt_attn_q_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_k_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.txt_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: tuple = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.img_mod(vec).chunk(6, dim=-1)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.txt_mod(vec).chunk(6, dim=-1)
# Prepare image for attention.
img_modulated = self.img_norm1(img)
img_modulated = modulate(
img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
)
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# Apply QK-Norm if needed
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
# Apply RoPE if needed.
if freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
# Prepare txt for attention.
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(
txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# Apply QK-Norm if needed.
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
# Run actual attention.
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
assert (
cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
# attention computation start
if not self.hybrid_seq_parallel_attn:
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=img_k.shape[0],
)
else:
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# attention computation end
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
        # Calculate the img blocks.
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(
self.img_mlp(
modulate(
self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
)
),
gate=img_mod2_gate,
)
        # Calculate the txt blocks.
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(
self.txt_mlp(
modulate(
self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
)
),
gate=txt_mod2_gate,
)
return img, txt
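# --- Illustrative sketch (not part of the original file). A toy version of
# the double-stream pattern implemented by MMDoubleStreamBlock above: image
# and text tokens keep separate projections, attend jointly over the
# concatenated sequence, and are split back into two streams afterwards.
# Plain scaled_dot_product_attention stands in for the repo's varlen
# attention, so this is a shape-level sketch only; all names are assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class ToyDoubleStreamAttention(nn.Module):
    def __init__(self, hidden_size: int = 64, heads_num: int = 4):
        super().__init__()
        self.heads_num = heads_num
        self.img_qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.txt_qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.img_proj = nn.Linear(hidden_size, hidden_size)
        self.txt_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, img, txt):
        iq, ik, iv = rearrange(self.img_qkv(img), "B L (K H D) -> K B H L D", K=3, H=self.heads_num)
        tq, tk, tv = rearrange(self.txt_qkv(txt), "B L (K H D) -> K B H L D", K=3, H=self.heads_num)
        # joint attention over the concatenated image + text sequence
        q = torch.cat((iq, tq), dim=2)
        k = torch.cat((ik, tk), dim=2)
        v = torch.cat((iv, tv), dim=2)
        attn = rearrange(F.scaled_dot_product_attention(q, k, v), "B H L D -> B L (H D)")
        # split the joint result back into the two streams and project separately
        img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        return img + self.img_proj(img_attn), txt + self.txt_proj(txt_attn)


if __name__ == "__main__":
    blk = ToyDoubleStreamAttention()
    img_out, txt_out = blk(torch.randn(2, 32, 64), torch.randn(2, 8, 64))
    print(img_out.shape, txt_out.shape)  # torch.Size([2, 32, 64]) torch.Size([2, 8, 64])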
class MMSingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and an adapted modulation interface.
Also refer to (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
self.scale = qk_scale or head_dim ** -0.5
# qkv and mlp_in
self.linear1 = nn.Linear(
hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
)
# proj and mlp_out
self.linear2 = nn.Linear(
hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.pre_norm = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.mlp_act = get_activation_layer(mlp_act_type)()
self.modulation = ModulateDiT(
hidden_size,
factor=3,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
qkv, mlp = torch.split(
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
# Apply RoPE if needed.
if freqs_cis is not None:
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
# Compute attention.
assert (
cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
# attention computation start
if not self.hybrid_seq_parallel_attn:
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=x.shape[0],
)
else:
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# attention computation end
# Compute activation in mlp stream, cat again and run second linear layer.
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + apply_gate(output, gate=mod_gate)
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
"""
HunyuanVideo Transformer backbone
    Inherits from ModelMixin and ConfigMixin for compatibility with diffusers pipelines and samplers.
Reference:
[1] Flux.1: https://github.com/black-forest-labs/flux
[2] MMDiT: http://arxiv.org/abs/2403.03206
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
patch_size: list
The size of the patch.
in_channels: int
The number of input channels.
out_channels: int
The number of output channels.
hidden_size: int
The hidden size of the transformer backbone.
heads_num: int
The number of attention heads.
mlp_width_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
mlp_act_type: str
The activation function of the MLP in the transformer block.
    mm_double_blocks_depth: int
        The number of double-stream transformer blocks.
    mm_single_blocks_depth: int
        The number of single-stream transformer blocks.
rope_dim_list: list
The dimension of the rotary embedding for t, h, w.
qkv_bias: bool
Whether to use bias in the qkv linear layer.
qk_norm: bool
Whether to use qk norm.
qk_norm_type: str
The type of qk norm.
guidance_embed: bool
Whether to use guidance embedding for distillation.
text_projection: str
The type of the text projection, default is single_refiner.
use_attention_mask: bool
Whether to use attention mask for text encoder.
dtype: torch.dtype
The dtype of the model.
device: torch.device
The device of the model.
"""
@register_to_config
def __init__(
self,
args: Any,
patch_size: list = [1, 2, 2],
in_channels: int = 4, # Should be VAE.config.latent_channels.
out_channels: int = None,
hidden_size: int = 3072,
heads_num: int = 24,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
mm_double_blocks_depth: int = 20,
mm_single_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
qkv_bias: bool = True,
qk_norm: bool = True,
qk_norm_type: str = "rms",
guidance_embed: bool = False, # For modulation.
text_projection: str = "single_refiner",
use_attention_mask: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.unpatchify_channels = self.out_channels
self.guidance_embed = guidance_embed
self.rope_dim_list = rope_dim_list
        # Text projection: either a plain linear layer ("linear") or a
        # TokenRefiner ("single_refiner", the default). See more details (LI-DiT): http://arxiv.org/abs/2406.11831
self.use_attention_mask = use_attention_mask
self.text_projection = text_projection
self.text_states_dim = args.text_states_dim
self.text_states_dim_2 = args.text_states_dim_2
if hidden_size % heads_num != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
)
pe_dim = hidden_size // heads_num
if sum(rope_dim_list) != pe_dim:
raise ValueError(
f"Got {rope_dim_list} but expected positional dim {pe_dim}"
)
self.hidden_size = hidden_size
self.heads_num = heads_num
# image projection
self.img_in = PatchEmbed(
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
)
# text projection
if self.text_projection == "linear":
self.txt_in = TextProjection(
self.text_states_dim,
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs,
)
elif self.text_projection == "single_refiner":
self.txt_in = SingleTokenRefiner(
self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
# time modulation
self.time_in = TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
# text modulation
self.vector_in = MLPEmbedder(
self.text_states_dim_2, self.hidden_size, **factory_kwargs
)
# guidance modulation
self.guidance_in = (
TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
if guidance_embed
else None
)
# double blocks
self.double_blocks = nn.ModuleList(
[
MMDoubleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
]
)
# single blocks
self.single_blocks = nn.ModuleList(
[
MMSingleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(mm_single_blocks_depth)
]
)
self.final_layer = FinalLayer(
self.hidden_size,
self.patch_size,
self.out_channels,
get_activation_layer("silu"),
**factory_kwargs,
)
def enable_deterministic(self):
for block in self.double_blocks:
block.enable_deterministic()
for block in self.single_blocks:
block.enable_deterministic()
def disable_deterministic(self):
for block in self.double_blocks:
block.disable_deterministic()
for block in self.single_blocks:
block.disable_deterministic()
def forward(
self,
x: torch.Tensor,
t: torch.Tensor, # Should be in range(0, 1000).
text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,  # Used to build cu_seqlens and, when use_attention_mask is set, by the token refiner.
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
out = {}
img = x
txt = text_states
_, _, ot, oh, ow = x.shape
tt, th, tw = (
ot // self.patch_size[0],
oh // self.patch_size[1],
ow // self.patch_size[2],
)
# Prepare modulation vectors.
vec = self.time_in(t)
# text modulation
vec = vec + self.vector_in(text_states_2)
# guidance modulation
if self.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
            # The timestep-style embedding for guidance is computed inside guidance_in (a TimestepEmbedder).
vec = vec + self.guidance_in(guidance)
# Embed image and text.
img = self.img_in(img)
if self.text_projection == "linear":
txt = self.txt_in(txt)
elif self.text_projection == "single_refiner":
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
        # Compute cu_seqlens and max_seqlen for flash attention
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
cu_seqlens_kv = cu_seqlens_q
max_seqlen_q = img_seq_len + txt_seq_len
max_seqlen_kv = max_seqlen_q
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# --------------------- Pass through DiT blocks ------------------------
for _, block in enumerate(self.double_blocks):
double_block_args = [
img,
txt,
vec,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
freqs_cis,
]
img, txt = block(*double_block_args)
# Merge txt and img to pass through single stream blocks.
x = torch.cat((img, txt), 1)
if len(self.single_blocks) > 0:
for _, block in enumerate(self.single_blocks):
single_block_args = [
x,
vec,
txt_seq_len,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
(freqs_cos, freqs_sin),
]
x = block(*single_block_args)
img = x[:, :img_seq_len, ...]
# ---------------------------- Final layer ------------------------------
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
img = self.unpatchify(img, tt, th, tw)
if return_dict:
out["x"] = img
return out
return img
def unpatchify(self, x, t, h, w):
"""
        x: (N, L, prod(patch_size) * C)
        imgs: (N, C, T, H, W)
"""
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def params_count(self):
counts = {
"double": sum(
[
sum(p.numel() for p in block.img_attn_qkv.parameters())
+ sum(p.numel() for p in block.img_attn_proj.parameters())
+ sum(p.numel() for p in block.img_mlp.parameters())
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
+ sum(p.numel() for p in block.txt_mlp.parameters())
for block in self.double_blocks
]
),
"single": sum(
[
sum(p.numel() for p in block.linear1.parameters())
+ sum(p.numel() for p in block.linear2.parameters())
for block in self.single_blocks
]
),
"total": sum(p.numel() for p in self.parameters()),
}
counts["attn+mlp"] = counts["double"] + counts["single"]
return counts
#################################################################################
# HunyuanVideo Configs #
#################################################################################
HUNYUAN_VIDEO_CONFIG = {
"HYVideo-T/2": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
},
"HYVideo-T/2-cfgdistill": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
"guidance_embed": True,
},
}
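# --- Illustrative sketch (not part of the original file). How the sequence
# lengths used in HYVideoDiffusionTransformer.forward arise from the latent
# shape and patch_size. The latent shape and text length below are
# placeholders chosen for illustration, not values prescribed by the repo.
if __name__ == "__main__":
    patch_size = [1, 2, 2]
    B, C, ot, oh, ow = 1, 16, 9, 40, 64  # placeholder latent video shape
    txt_seq_len = 256                    # placeholder text token count
    tt, th, tw = ot // patch_size[0], oh // patch_size[1], ow // patch_size[2]
    img_seq_len = tt * th * tw
    max_seqlen_q = img_seq_len + txt_seq_len  # mirrors the computation in forward()
    print(tt, th, tw, img_seq_len, max_seqlen_q)  # 9 20 32 5760 6016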
from typing import Callable
import torch
import torch.nn as nn
class ModulateDiT(nn.Module):
"""Modulation layer for DiT."""
def __init__(
self,
hidden_size: int,
factor: int,
act_layer: Callable,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.act = act_layer()
self.linear = nn.Linear(
hidden_size, factor * hidden_size, bias=True, **factory_kwargs
)
# Zero-initialize the modulation
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.act(x))
def modulate(x, shift=None, scale=None):
"""modulate by shift and scale
Args:
x (torch.Tensor): input tensor.
shift (torch.Tensor, optional): shift tensor. Defaults to None.
scale (torch.Tensor, optional): scale tensor. Defaults to None.
Returns:
torch.Tensor: the output tensor after modulate.
"""
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate
Args:
x (torch.Tensor): input tensor.
gate (torch.Tensor, optional): gate tensor. Defaults to None.
tanh (bool, optional): whether to use tanh function. Defaults to False.
Returns:
torch.Tensor: the output tensor after apply gate.
"""
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
def ckpt_wrapper(module):
def ckpt_forward(*inputs):
outputs = module(*inputs)
return outputs
return ckpt_forward
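# --- Illustrative sketch (not part of the original file). Minimal usage of
# the helpers above: ModulateDiT maps a conditioning vector to
# (shift, scale, gate), modulate() applies shift/scale per token, and
# apply_gate() scales the residual branch. Shapes are illustrative only.
if __name__ == "__main__":
    hidden_size, B, L = 64, 2, 16
    mod = ModulateDiT(hidden_size, factor=3, act_layer=nn.SiLU)
    x = torch.randn(B, L, hidden_size)  # token sequence
    vec = torch.randn(B, hidden_size)   # conditioning vector (e.g. a timestep embedding)
    # The linear layer is zero-initialized, so shift = scale = gate = 0 at init.
    shift, scale, gate = mod(vec).chunk(3, dim=-1)  # each (B, hidden_size)
    x_mod = modulate(x, shift=shift, scale=scale)   # (B, L, hidden_size)
    x_out = x + apply_gate(x_mod, gate=gate)        # gated residual update
    print(x_out.shape)  # torch.Size([2, 16, 64])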
import torch
import torch.nn as nn
from lightop import RMSNorm
from lightop import LayerNorm
# class RMSNorm(nn.Module):
# def __init__(
# self,
# dim: int,
# elementwise_affine=True,
# eps: float = 1e-6,
# device=None,
# dtype=None,
# ):
# """
# Initialize the RMSNorm normalization layer.
# Args:
# dim (int): The dimension of the input tensor.
# eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
# Attributes:
# eps (float): A small value added to the denominator for numerical stability.
# weight (nn.Parameter): Learnable scaling parameter.
# """
# factory_kwargs = {"device": device, "dtype": dtype}
# super().__init__()
# self.eps = eps
# if elementwise_affine:
# self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
# def _norm(self, x):
# """
# Apply the RMSNorm normalization to the input tensor.
# Args:
# x (torch.Tensor): The input tensor.
# Returns:
# torch.Tensor: The normalized tensor.
# """
# return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
# def forward(self, x):
# """
# Forward pass through the RMSNorm layer.
# Args:
# x (torch.Tensor): The input tensor.
# Returns:
# torch.Tensor: The output tensor after applying RMSNorm.
# """
# output = self._norm(x.float()).type_as(x)
# if hasattr(self, "weight"):
# output = output * self.weight
# return output
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
import torch
import torch.nn as nn
from lightop import RMSNorm
from lightop import LayerNorm
# class RMSNorm(nn.Module):
# def __init__(
# self,
# dim: int,
# elementwise_affine=True,
# eps: float = 1e-6,
# device=None,
# dtype=None,
# ):
# """
# Initialize the RMSNorm normalization layer.
# Args:
# dim (int): The dimension of the input tensor.
# eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
# Attributes:
# eps (float): A small value added to the denominator for numerical stability.
# weight (nn.Parameter): Learnable scaling parameter.
# """
# factory_kwargs = {"device": device, "dtype": dtype}
# super().__init__()
# self.eps = eps
# if elementwise_affine:
# self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
# def _norm(self, x):
# """
# Apply the RMSNorm normalization to the input tensor.
# Args:
# x (torch.Tensor): The input tensor.
# Returns:
# torch.Tensor: The normalized tensor.
# """
# return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
# def forward(self, x):
# """
# Forward pass through the RMSNorm layer.
# Args:
# x (torch.Tensor): The input tensor.
# Returns:
# torch.Tensor: The output tensor after applying RMSNorm.
# """
# output = self._norm(x.float()).type_as(x)
# if hasattr(self, "weight"):
# output = output * self.weight
# return output
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")