# This file implements USP (Unified Sequence Parallelism) for torch < 2.5.0.
import torch
from torch.nn import functional as F
import torch.distributed._functional_collectives as ft_c
from yunchang.globals import PROCESS_GROUP
from yunchang.ring.ring_flash_attn import ring_flash_attn_forward
from xfuser.core.distributed import (
get_sequence_parallel_world_size,
get_ulysses_parallel_world_size,
get_ring_parallel_world_size,
)
def ring_attn(query, key, value, dropout_p=0.0, is_causal=False):
query = query.transpose(1,2).contiguous()
key = key.transpose(1,2).contiguous()
value = value.transpose(1,2).contiguous()
out, *_ = ring_flash_attn_forward(
PROCESS_GROUP.RING_PG,
query,
key,
value,
softmax_scale=query.shape[-1] ** (-0.5),
dropout_p=dropout_p,
causal=is_causal,
)
out = out.transpose(1,2).contiguous()
return out
def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
"""
When tracing the code, the result tensor is not an AsyncCollectiveTensor,
so we cannot call ``wait()``.
"""
if isinstance(tensor, ft_c.AsyncCollectiveTensor):
return tensor.wait()
return tensor
def _sdpa_all_to_all_single(x):
x_shape = x.shape
x = x.flatten()
x = ft_c.all_to_all_single(x, output_split_sizes=None, input_split_sizes=None, group=PROCESS_GROUP.ULYSSES_PG)
x = _maybe_wait(x)
x = x.reshape(x_shape)
return x
def _ft_c_input_all_to_all(x):
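    # All-to-all over the Ulysses group: scatter attention heads, gather the
    # sequence, i.e. (b, h, s_local, d) -> (b, h // ulysses_ws, s_local * ulysses_ws, d).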
world_size = get_ulysses_parallel_world_size()
if world_size <= 1:
return x
assert x.ndim == 4, "x must have 4 dimensions, got {}".format(x.ndim)
b, h, s, d = x.shape
assert h % world_size == 0, "h must be divisible by world_size, got {} and {}".format(h, world_size)
x = x.permute(1, 0, 2, 3).contiguous()
x = _sdpa_all_to_all_single(x)
x = x.reshape(world_size, h // world_size, b, -1, d).permute(2, 1, 0, 3, 4).reshape(b, h // world_size, -1, d)
return x
def _ft_c_output_all_to_all(x):
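    # Inverse all-to-all: gather attention heads, scatter the sequence,
    # i.e. (b, h_local, s, d) -> (b, h_local * ulysses_ws, s // ulysses_ws, d).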
world_size = get_ulysses_parallel_world_size()
if world_size <= 1:
return x
assert x.ndim == 4, "x must have 4 dimensions, got {}".format(x.ndim)
b, h, s, d = x.shape
assert s % world_size == 0, "s must be divisible by world_size, got {} and {}".format(s, world_size)
x = x.permute(2, 0, 1, 3).contiguous()
x = _sdpa_all_to_all_single(x)
x = x.reshape(world_size, s // world_size, b, -1, d).permute(2, 0, 3, 1, 4).reshape(b, -1, s // world_size, d)
return x
@torch.compiler.disable
def USP(query, key, value, dropout_p=0.0, is_causal=False):
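    """Hybrid Ulysses + ring sequence-parallel attention (USP).

    Dispatches to plain scaled_dot_product_attention when no sequence
    parallelism is active, to ring attention when only the ring group is used,
    and otherwise wraps the inner attention call in the Ulysses head/sequence
    all-to-all. Expects (batch, heads, seq_len, head_dim) tensors.
    """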
if get_sequence_parallel_world_size() == 1:
out = F.scaled_dot_product_attention(
query, key, value, dropout_p=dropout_p, is_causal=is_causal
)
elif get_ulysses_parallel_world_size() == 1:
out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
elif get_ulysses_parallel_world_size() > 1:
query = _ft_c_input_all_to_all(query)
key = _ft_c_input_all_to_all(key)
value = _ft_c_input_all_to_all(value)
if get_ring_parallel_world_size() == 1:
out = F.scaled_dot_product_attention(
query, key, value, dropout_p=dropout_p, is_causal=is_causal
)
else:
out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
out = _ft_c_output_all_to_all(out)
return out
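
# Illustrative, single-process sanity check for the Ulysses input all-to-all
# layout (a hedged sketch: not part of xfuser's API, values are arbitrary).
# Each simulated rank starts with all heads and one sequence shard and ends
# with one head group and the full sequence, matching _ft_c_input_all_to_all.
if __name__ == "__main__":
    ws, b, h, s, d = 2, 1, 4, 6, 8
    full = torch.randn(b, h, s, d)
    seq_shards = list(full.chunk(ws, dim=2))  # per-rank inputs: (b, h, s/ws, d)
    for rank in range(ws):
        # gather this rank's head group from every rank's sequence shard
        pieces = [shard.chunk(ws, dim=1)[rank] for shard in seq_shards]
        out = torch.cat(pieces, dim=2)  # (b, h/ws, s, d)
        assert torch.equal(out, full[:, rank * h // ws : (rank + 1) * h // ws])
    print("Ulysses input all-to-all layout check passed")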
from .base_model import xFuserModelBaseWrapper
__all__ = [
"xFuserModelBaseWrapper"
]
from abc import abstractmethod, ABCMeta
from typing import Dict, List, Optional, Type, Union
from functools import wraps
import torch.nn as nn
from xfuser.config import InputConfig, ParallelConfig, RuntimeConfig
from xfuser.core.cache_manager.cache_manager import get_cache_manager
from xfuser.core.distributed.parallel_state import get_sequence_parallel_world_size
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper
from xfuser.model_executor.layers import *
from xfuser.core.distributed import get_world_group
from xfuser.logger import init_logger
logger = init_logger(__name__)
class xFuserModelBaseWrapper(nn.Module, xFuserBaseWrapper, metaclass=ABCMeta):
wrapped_layers: List[xFuserLayerBaseWrapper]
def __init__(self, module: nn.Module):
super().__init__()
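        # super(nn.Module, self) skips nn.Module in the MRO, so this call
        # reaches xFuserBaseWrapper.__init__ with the wrapped module.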
super(nn.Module, self).__init__(
module=module,
)
def __getattr__(self, name: str):
if "_parameters" in self.__dict__:
_parameters = self.__dict__["_parameters"]
if name in _parameters:
return _parameters[name]
if "_buffers" in self.__dict__:
_buffers = self.__dict__["_buffers"]
if name in _buffers:
return _buffers[name]
if "_modules" in self.__dict__:
modules = self.__dict__["_modules"]
if name in modules:
return modules[name]
try:
return getattr(self.module, name)
except RecursionError:
raise AttributeError(
f"module {type(self.module).__name__} has no " f"attribute {name}"
)
def reset_activation_cache(self):
for layer in self.wrapped_layers:
if hasattr(layer, "activation_cache"):
layer.activation_cache = None
else:
logger.info(
f"layer {type(layer)} has no attribute "
f"activation_cache, do not need to reset"
)
def _wrap_layers(
self,
model: Optional[nn.Module] = None,
submodule_classes_to_wrap: List[Type] = [],
submodule_name_to_wrap: List[str] = [],
submodule_addition_args: Dict[str, Dict] = {},
) -> Union[nn.Module, None]:
wrapped_layers = []
wrap_self_module = False
if model is None:
wrap_self_module = True
model = self.module
for name, module in model.named_modules():
if isinstance(module, xFuserLayerBaseWrapper):
continue
for subname, submodule in module.named_children():
need_wrap = subname in submodule_name_to_wrap
for class_to_wrap in submodule_classes_to_wrap:
if isinstance(submodule, class_to_wrap):
need_wrap = True
break
if need_wrap:
wrapper = xFuserLayerWrappersRegister.get_wrapper(submodule)
additional_args = submodule_addition_args.get(subname, {})
logger.info(
f"[RANK {get_world_group().rank}] "
f"Wrapping {name}.{subname} in model class "
f"{model.__class__.__name__} with "
f"{wrapper.__name__}"
)
                    # Latte's temporal self-attention blocks get a dedicated flag;
                    # otherwise forward any additional per-submodule arguments.
                    if "temporal_transformer_blocks" in name and subname == "attn1":
                        setattr(
                            module,
                            subname,
                            wrapper(submodule, latte_temporal_attention=True),
                        )
                    elif additional_args:
                        setattr(
                            module,
                            subname,
                            wrapper(
                                submodule,
                                **additional_args,
                            ),
                        )
                    else:
                        setattr(
                            module,
                            subname,
                            wrapper(submodule),
                        )
# if isinstance(getattr(module, subname), xFuserPatchEmbedWrapper):
wrapped_layers.append(getattr(module, subname))
self.wrapped_layers = wrapped_layers
if wrap_self_module:
self.module = model
else:
return model
def _register_cache(
self,
):
for layer in self.wrapped_layers:
if isinstance(layer, xFuserAttentionWrapper):
# if getattr(layer.processor, 'use_long_ctx_attn_kvcache', False):
# TODO(Eigensystem): remove use_long_ctx_attn_kvcache flag
if get_sequence_parallel_world_size() == 1 or not getattr(
layer.processor, "use_long_ctx_attn_kvcache", False
):
get_cache_manager().register_cache_entry(
layer, layer_type="attn", cache_type="naive_cache"
)
else:
get_cache_manager().register_cache_entry(
layer,
layer_type="attn",
cache_type="sequence_parallel_attn_cache",
)
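
# Typical use of _wrap_layers from a transformer wrapper subclass (see the
# concrete wrappers below): pass, e.g., submodule_classes_to_wrap=[nn.Conv2d,
# PatchEmbed] and submodule_name_to_wrap=["attn1"]; every matching child module
# is replaced in place by the wrapper that xFuserLayerWrappersRegister returns,
# and the wrapped layers are collected for activation-cache registration.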
from .register import xFuserTransformerWrappersRegister
from .base_transformer import xFuserTransformerBaseWrapper
from .pixart_transformer_2d import xFuserPixArtTransformer2DWrapper
from .transformer_sd3 import xFuserSD3Transformer2DWrapper
from .transformer_flux import xFuserFluxTransformer2DWrapper
from .latte_transformer_3d import xFuserLatteTransformer3DWrapper
from .hunyuan_transformer_2d import xFuserHunyuanDiT2DWrapper
from .cogvideox_transformer_3d import xFuserCogVideoXTransformer3DWrapper
__all__ = [
"xFuserTransformerWrappersRegister",
"xFuserTransformerBaseWrapper",
"xFuserPixArtTransformer2DWrapper",
"xFuserSD3Transformer2DWrapper",
"xFuserFluxTransformer2DWrapper",
"xFuserLatteTransformer3DWrapper",
"xFuserCogVideoXTransformer3DWrapper",
"xFuserHunyuanDiT2DWrapper",
]
from abc import abstractmethod, ABCMeta
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
import torch.nn as nn
from xfuser.core.distributed import (
get_pipeline_parallel_rank,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_pipeline_parallel_world_size,
get_sequence_parallel_world_size,
get_tensor_model_parallel_world_size,
)
from xfuser.core.fast_attention import get_fast_attn_enable
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.logger import init_logger
from xfuser.model_executor.models import xFuserModelBaseWrapper
logger = init_logger(__name__)
class StageInfo:
def __init__(self):
self.after_flags: Dict[str, bool] = {}
class xFuserTransformerBaseWrapper(xFuserModelBaseWrapper, metaclass=ABCMeta):
# transformer: original transformer model (for example Transformer2DModel)
def __init__(
self,
transformer: nn.Module,
submodule_classes_to_wrap: List[Type] = [],
submodule_name_to_wrap: List = [],
submodule_addition_args: Dict = {},
transformer_blocks_name: List[str] = ["transformer_blocks"],
):
self.stage_info = None
transformer = self._convert_transformer_for_parallel(
transformer,
submodule_classes_to_wrap=submodule_classes_to_wrap,
submodule_name_to_wrap=submodule_name_to_wrap,
submodule_addition_args=submodule_addition_args,
transformer_blocks_name=transformer_blocks_name,
)
super().__init__(module=transformer)
def _convert_transformer_for_parallel(
self,
transformer: nn.Module,
submodule_classes_to_wrap: List[Type] = [],
submodule_name_to_wrap: List = [],
submodule_addition_args: Dict = {},
transformer_blocks_name: List[str] = [],
) -> nn.Module:
if (
get_pipeline_parallel_world_size() == 1
and get_sequence_parallel_world_size() == 1
and get_tensor_model_parallel_world_size() == 1
            and not get_fast_attn_enable()
):
return transformer
else:
transformer = self._split_transformer_blocks(
transformer, transformer_blocks_name
)
transformer = self._wrap_layers(
model=transformer,
submodule_classes_to_wrap=submodule_classes_to_wrap,
submodule_name_to_wrap=submodule_name_to_wrap,
submodule_addition_args=submodule_addition_args,
)
self._register_cache()
return transformer
def _split_transformer_blocks(
self,
transformer: nn.Module,
blocks_name: List[str] = [],
):
for block_name in blocks_name:
if not hasattr(transformer, block_name):
raise AttributeError(
f"'{transformer.__class__.__name__}' object has no attribute "
f"'{block_name}'."
)
# transformer layer split
attn_layer_num_for_pp = (
get_runtime_state().parallel_config.pp_config.attn_layer_num_for_pp
)
pp_rank = get_pipeline_parallel_rank()
pp_world_size = get_pipeline_parallel_world_size()
blocks_list = {
block_name: getattr(transformer, block_name) for block_name in blocks_name
}
num_blocks_list = [len(blocks) for blocks in blocks_list.values()]
self.blocks_idx = {
name: [sum(num_blocks_list[:i]), sum(num_blocks_list[: i + 1])]
for i, name in enumerate(blocks_name)
}
if attn_layer_num_for_pp is not None:
assert sum(attn_layer_num_for_pp) == sum(num_blocks_list), (
"Sum of attn_layer_num_for_pp should be equal to the "
"number of all the transformer blocks"
)
stage_block_start_idx = sum(attn_layer_num_for_pp[:pp_rank])
stage_block_end_idx = sum(attn_layer_num_for_pp[: pp_rank + 1])
else:
num_blocks_per_stage = (
sum(num_blocks_list) + pp_world_size - 1
) // pp_world_size
stage_block_start_idx = pp_rank * num_blocks_per_stage
stage_block_end_idx = min(
(pp_rank + 1) * num_blocks_per_stage,
sum(num_blocks_list),
)
self.stage_info = StageInfo()
        for name, (blocks_start, blocks_end) in self.blocks_idx.items():
if (
blocks_end <= stage_block_start_idx
or stage_block_end_idx <= blocks_start
):
setattr(transformer, name, nn.ModuleList([]))
self.stage_info.after_flags[name] = False
elif stage_block_start_idx <= blocks_start:
if blocks_end <= stage_block_end_idx:
self.stage_info.after_flags[name] = True
else:
setattr(
transformer,
name,
blocks_list[name][: -(blocks_end - stage_block_end_idx)],
)
self.stage_info.after_flags[name] = False
elif blocks_start < stage_block_start_idx:
if blocks_end <= stage_block_end_idx:
setattr(
transformer,
name,
blocks_list[name][stage_block_start_idx - blocks_start :],
)
self.stage_info.after_flags[name] = True
                else:  # blocks_end > stage_block_end_idx
setattr(
transformer,
name,
blocks_list[name][
stage_block_start_idx
- blocks_start : stage_block_end_idx
- blocks_end
],
)
self.stage_info.after_flags[name] = False
return transformer
@abstractmethod
def forward(self, *args, **kwargs):
pass
def _get_patch_height_width(self) -> Tuple[int, int]:
patch_size = get_runtime_state().backbone_patch_size
vae_scale_factor = get_runtime_state().vae_scale_factor
width = get_runtime_state().input_config.width // patch_size // vae_scale_factor
if get_runtime_state().patch_mode:
height = (
get_runtime_state().pp_patches_height[
get_runtime_state().pipeline_patch_idx
]
// patch_size
)
else:
height = sum(get_runtime_state().pp_patches_height) // patch_size
return height, width
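
# Worked example of the even block split in _split_transformer_blocks when
# attn_layer_num_for_pp is None: with 28 blocks over 4 pipeline stages,
# num_blocks_per_stage = ceil(28 / 4) = 7, giving stage ranges
# [0, 7), [7, 14), [14, 21), [21, 28). The helper below is an illustrative
# sketch of that index arithmetic only, not part of xfuser's API.
def _example_stage_ranges(num_blocks: int, pp_world_size: int):
    per_stage = (num_blocks + pp_world_size - 1) // pp_world_size
    return [
        (rank * per_stage, min((rank + 1) * per_stage, num_blocks))
        for rank in range(pp_world_size)
    ]
# _example_stage_ranges(28, 4) -> [(0, 7), (7, 14), (14, 21), (21, 28)]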
from typing import Optional, Dict, Any, Union, List, Tuple, Type
import torch
import torch.distributed
import torch.nn as nn
from diffusers.models.embeddings import PatchEmbed, CogVideoXPatchEmbed
from diffusers.models import CogVideoXTransformer3DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_version,
    scale_lora_layers,
    unscale_lora_layers,
)
from xfuser.model_executor.models import xFuserModelBaseWrapper
from xfuser.logger import init_logger
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper
from xfuser.core.distributed import (
get_data_parallel_world_size,
get_sequence_parallel_world_size,
get_pipeline_parallel_world_size,
get_classifier_free_guidance_world_size,
get_classifier_free_guidance_rank,
get_pipeline_parallel_rank,
get_pp_group,
get_world_group,
get_cfg_group,
get_sp_group,
get_runtime_state,
initialize_runtime_state
)
from xfuser.model_executor.models.transformers.register import xFuserTransformerWrappersRegister
from xfuser.model_executor.models.transformers.base_transformer import xFuserTransformerBaseWrapper
logger = init_logger(__name__)
@xFuserTransformerWrappersRegister.register(CogVideoXTransformer3DModel)
class xFuserCogVideoXTransformer3DWrapper(xFuserTransformerBaseWrapper):
def __init__(
self,
transformer: CogVideoXTransformer3DModel,
):
super().__init__(
transformer=transformer,
submodule_classes_to_wrap=[nn.Conv2d, CogVideoXPatchEmbed],
submodule_name_to_wrap=["attn1"]
)
@xFuserBaseWrapper.forward_check_condition
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
ofs: Optional[Union[int, float, torch.LongTensor]] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.ofs_embedding is not None:
ofs_emb = self.ofs_proj(ofs)
ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
ofs_emb = self.ofs_embedding(ofs_emb)
emb = emb + ofs_emb
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
)
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 4. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 5. Unpatchify
p = self.config.patch_size
p_t = self.config.patch_size_t
if p_t is None:
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
else:
output = hidden_states.reshape(
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
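
# Shape trace for the unpatchify step above when patch_size_t is None (p is
# patch_size, C is out_channels): proj_out yields (B, F * H/p * W/p, C * p * p),
# the reshape gives (B, F, H/p, W/p, C, p, p), and the permute plus the two
# flattens produce (B, F, C, H, W).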
from typing import List, Optional, Dict, Any
import torch
import torch.distributed
import torch.nn as nn
from diffusers import HunyuanDiT2DModel
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
from diffusers.utils import is_torch_version
from xfuser.logger import init_logger
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper
from xfuser.core.distributed import (
get_pipeline_parallel_rank,
get_pipeline_parallel_world_size,
is_pipeline_first_stage,
is_pipeline_last_stage,
)
from .register import xFuserTransformerWrappersRegister
from .base_transformer import xFuserTransformerBaseWrapper
logger = init_logger(__name__)
# adapted from
# https://github.com/huggingface/diffusers/blob/b5f591fea843cb4bf1932bd94d1db5d5eebe3298/src/diffusers/models/transformers/hunyuan_transformer_2d.py#L203
@xFuserTransformerWrappersRegister.register(HunyuanDiT2DModel)
class xFuserHunyuanDiT2DWrapper(xFuserTransformerBaseWrapper):
def __init__(
self,
transformer: HunyuanDiT2DModel,
):
super().__init__(
transformer=transformer,
submodule_classes_to_wrap=[nn.Conv2d, PatchEmbed],
submodule_name_to_wrap=["attn1"],
)
def _split_transformer_blocks(
self,
transformer: nn.Module,
blocks_name: List[str] = [],
):
if not hasattr(transformer, "blocks"):
raise AttributeError(
f"'{transformer.__class__.__name__}' object has no attribute 'blocks'"
)
# transformer layer split
attn_layer_num_for_pp = (
get_runtime_state().parallel_config.pp_config.attn_layer_num_for_pp
)
pp_rank = get_pipeline_parallel_rank()
pp_world_size = get_pipeline_parallel_world_size()
if attn_layer_num_for_pp is not None:
assert sum(attn_layer_num_for_pp) == len(transformer.blocks), (
"Sum of attn_layer_num_for_pp should be equal to the "
"number of transformer blocks"
)
if is_pipeline_first_stage():
transformer.blocks = transformer.blocks[: attn_layer_num_for_pp[0]]
else:
transformer.blocks = transformer.blocks[
sum(attn_layer_num_for_pp[: pp_rank - 1]) : sum(
attn_layer_num_for_pp[:pp_rank]
)
]
else:
num_blocks_per_stage = (
len(transformer.blocks) + pp_world_size - 1
) // pp_world_size
start_idx = pp_rank * num_blocks_per_stage
end_idx = min(
(pp_rank + 1) * num_blocks_per_stage,
len(transformer.blocks),
)
transformer.blocks = transformer.blocks[start_idx:end_idx]
# position embedding
if not is_pipeline_first_stage():
transformer.pos_embed = None
if not is_pipeline_last_stage():
transformer.norm_out = None
transformer.proj_out = None
return transformer
@xFuserBaseWrapper.forward_check_condition
def forward(
self,
hidden_states,
timestep,
encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
image_rotary_emb=None,
skips=None,
controlnet_block_samples=None,
return_dict=True,
):
"""
The [`HunyuanDiT2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
The input tensor.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step.
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of `BertModel`.
text_embedding_mask: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of `BertModel`.
encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
text_embedding_mask_t5: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of T5 Text Encoder.
image_meta_size (torch.Tensor):
Conditional embedding indicate the image sizes
style: torch.Tensor:
Conditional embedding indicate the style
image_rotary_emb (`torch.Tensor`):
The image rotary embeddings to apply on query and key tensors during attention calculation.
return_dict: bool
Whether to return a dictionary.
"""
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
assert controlnet_block_samples is None
# * get height & width from runtime state
height, width = self._get_patch_height_width()
# * only pp rank 0 needs pos_embed (patchify)
if is_pipeline_first_stage():
hidden_states = self.pos_embed(hidden_states)
#! ORIGIN
# height, width = hidden_states.shape[-2:]
# hidden_states = self.pos_embed(hidden_states)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
temb = self.time_extra_emb(
timestep,
encoder_hidden_states_t5,
image_meta_size,
style,
hidden_dtype=timestep.dtype,
) # [B, D]
# text projection
batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
encoder_hidden_states_t5 = self.text_embedder(
encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
)
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
batch_size, sequence_length, -1
)
encoder_hidden_states = torch.cat(
[encoder_hidden_states, encoder_hidden_states_t5], dim=1
)
text_embedding_mask = torch.cat(
[text_embedding_mask, text_embedding_mask_t5], dim=-1
)
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
encoder_hidden_states = torch.where(
text_embedding_mask, encoder_hidden_states, self.text_embedding_padding
)
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
if get_pipeline_parallel_world_size() == 1:
skips = []
num_layers = len(self.blocks)
for layer, block in enumerate(self.blocks):
if layer > num_layers // 2:
skip = skips.pop()
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
skip=skip,
) # (N, L, D)
else:
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
) # (N, L, D)
if layer < (num_layers // 2 - 1):
skips.append(hidden_states)
else:
if get_pipeline_parallel_rank() >= get_pipeline_parallel_world_size() // 2:
assert skips is not None
skips = list(skips.unbind(0))
for layer, block in enumerate(self.blocks):
skip = skips.pop()
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
skip=skip,
) # (N, L, D)
assert len(skips) == 0
else:
skips = []
for layer, block in enumerate(self.blocks):
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
) # (N, L, D)
skips.append(hidden_states)
skips = torch.stack(skips, dim=0)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
# * only the last pp rank needs unpatchify
#! ---------------------------------------- ADD BELOW ----------------------------------------
if is_pipeline_last_stage():
#! ---------------------------------------- ADD ABOVE ----------------------------------------
# final layer
hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
hidden_states = self.proj_out(hidden_states)
# (N, L, patch_size ** 2 * out_channels)
# unpatchify: (N, out_channels, H, W)
patch_size = get_runtime_state().backbone_patch_size
hidden_states = hidden_states.reshape(
shape=(
hidden_states.shape[0],
height,
width,
patch_size,
patch_size,
self.out_channels,
)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(
hidden_states.shape[0],
self.out_channels,
height * patch_size,
width * patch_size,
)
)
#! ---------------------------------------- ADD BELOW ----------------------------------------
elif get_pipeline_parallel_rank() >= get_pipeline_parallel_world_size() // 2:
output = hidden_states
else:
output = hidden_states, skips
#! ---------------------------------------- ADD ABOVE ----------------------------------------
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
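
# Skip-connection bookkeeping in the single-stage HunyuanDiT path above: with
# num_layers = 40, blocks 0-18 push their outputs onto `skips` (19 entries),
# blocks 21-39 each pop one (19 pops), and blocks 19 and 20 neither push nor
# pop. With pipeline parallelism, ranks in the first half stack their skips and
# ranks in the second half unbind and consume them.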
from typing import Optional, Dict, Any, Union, List, Tuple, Type
import torch
import torch.distributed
import torch.nn as nn
from diffusers.models.embeddings import PatchEmbed
from diffusers.models import LatteTransformer3DModel
from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
from diffusers.utils import (
is_torch_version,
scale_lora_layers,
USE_PEFT_BACKEND,
unscale_lora_layers,
)
from xfuser.model_executor.models import xFuserModelBaseWrapper
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.logger import init_logger
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper
from xfuser.core.distributed import (
get_data_parallel_world_size,
get_sequence_parallel_world_size,
get_pipeline_parallel_world_size,
get_classifier_free_guidance_world_size,
get_classifier_free_guidance_rank,
get_pipeline_parallel_rank,
get_pp_group,
get_world_group,
get_cfg_group,
get_sp_group,
get_runtime_state,
initialize_runtime_state,
)
from xfuser.model_executor.models.transformers.register import (
xFuserTransformerWrappersRegister,
)
from xfuser.model_executor.models.transformers.base_transformer import (
xFuserTransformerBaseWrapper,
)
logger = init_logger(__name__)
@xFuserTransformerWrappersRegister.register(LatteTransformer3DModel)
class xFuserLatteTransformer3DWrapper(xFuserTransformerBaseWrapper):
def __init__(
self,
transformer: LatteTransformer3DModel,
):
super().__init__(
transformer=transformer,
submodule_classes_to_wrap=[nn.Conv2d, PatchEmbed],
submodule_name_to_wrap=["attn1"],
)
def forward(
self,
hidden_states: torch.Tensor,
timestep: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
enable_temporal_attentions: bool = True,
return_dict: bool = True,
):
"""
The [`LatteTransformer3DModel`] forward method.
Args:
hidden_states shape `(batch size, channel, num_frame, height, width)`:
Input `hidden_states`.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
enable_temporal_attentions:
(`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# Reshape hidden states
batch_size, channels, num_frame, height, width = hidden_states.shape
# batch_size channels num_frame height width -> (batch_size * num_frame) channels height width
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
-1, channels, height, width
)
# Input
height, width = (
hidden_states.shape[-2] // self.config.patch_size,
hidden_states.shape[-1] // self.config.patch_size,
)
num_patches = height * width
hidden_states = self.pos_embed(
hidden_states
        )  # already adds positional embeddings
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
timestep, embedded_timestep = self.adaln_single(
timestep,
added_cond_kwargs=added_cond_kwargs,
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
# Prepare text embeddings for spatial block
# batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
encoder_hidden_states = self.caption_projection(
encoder_hidden_states
) # 3 120 1152
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
num_frame, dim=0
).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
# Prepare timesteps for spatial and temporal block
timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(
-1, timestep.shape[-1]
)
timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(
-1, timestep.shape[-1]
)
# Spatial and temporal transformer blocks
for i, (spatial_block, temp_block) in enumerate(
zip(self.transformer_blocks, self.temporal_transformer_blocks)
):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
spatial_block,
hidden_states,
None, # attention_mask
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
None, # cross_attention_kwargs
None, # class_labels
use_reentrant=False,
)
else:
hidden_states = spatial_block(
hidden_states,
None, # attention_mask
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
None, # cross_attention_kwargs
None, # class_labels
)
if enable_temporal_attentions:
# (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size
hidden_states = hidden_states.reshape(
batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
).permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(
-1, hidden_states.shape[-2], hidden_states.shape[-1]
)
if i == 0 and num_frame > 1:
hidden_states = hidden_states + self.temp_pos_embed
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
None, # cross_attention_kwargs
None, # class_labels
use_reentrant=False,
)
else:
hidden_states = temp_block(
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
None, # cross_attention_kwargs
None, # class_labels
)
# (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size
hidden_states = hidden_states.reshape(
batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
).permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(
-1, hidden_states.shape[-2], hidden_states.shape[-1]
)
embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(
-1, embedded_timestep.shape[-1]
)
shift, scale = (
self.scale_shift_table[None] + embedded_timestep[:, None]
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(
-1,
height,
width,
self.config.patch_size,
self.config.patch_size,
self.out_channels,
)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(
-1,
self.out_channels,
height * self.config.patch_size,
width * self.config.patch_size,
)
)
output = output.reshape(
batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]
).permute(0, 2, 1, 3, 4)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
from typing import Optional, Dict, Any
import torch
import torch.distributed
import torch.nn as nn
from diffusers import PixArtTransformer2DModel
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
from diffusers.utils import is_torch_version
from xfuser.logger import init_logger
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper
from xfuser.core.distributed import is_pipeline_first_stage, is_pipeline_last_stage
from .register import xFuserTransformerWrappersRegister
from .base_transformer import xFuserTransformerBaseWrapper
logger = init_logger(__name__)
@xFuserTransformerWrappersRegister.register(PixArtTransformer2DModel)
class xFuserPixArtTransformer2DWrapper(xFuserTransformerBaseWrapper):
def __init__(
self,
transformer: PixArtTransformer2DModel,
):
super().__init__(
transformer=transformer,
submodule_classes_to_wrap=[nn.Conv2d, PatchEmbed],
submodule_name_to_wrap=["attn1"],
)
@xFuserBaseWrapper.forward_check_condition
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
The [`PixArtTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep (`torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (
1 - encoder_attention_mask.to(hidden_states.dtype)
) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
batch_size = hidden_states.shape[0]
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
# * get height & width from runtime state
height, width = self._get_patch_height_width()
# * only pp rank 0 needs pos_embed (patchify)
if is_pipeline_first_stage():
hidden_states = self.pos_embed(hidden_states)
#! ORIGIN
# height, width = (
# hidden_states.shape[-2] // self.config.patch_size,
# hidden_states.shape[-1] // self.config.patch_size,
# )
# hidden_states = self.pos_embed(hidden_states)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
timestep, embedded_timestep = self.adaln_single(
timestep,
added_cond_kwargs,
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
if self.caption_projection is not None:
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(
batch_size, -1, hidden_states.shape[-1]
)
# 2. Blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
None,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=None,
)
# 3. Output
# * only the last pp rank needs unpatchify
#! ---------------------------------------- ADD BELOW ----------------------------------------
if is_pipeline_last_stage():
#! ---------------------------------------- ADD ABOVE ----------------------------------------
shift, scale = (
self.scale_shift_table[None]
+ embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (
1 + scale.to(hidden_states.device)
) + shift.to(hidden_states.device)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(
-1,
height,
width,
self.config.patch_size,
self.config.patch_size,
self.out_channels,
)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(
-1,
self.out_channels,
height * self.config.patch_size,
width * self.config.patch_size,
)
)
#! ---------------------------------------- ADD BELOW ----------------------------------------
else:
output = hidden_states
#! ---------------------------------------- ADD ABOVE ----------------------------------------
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
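
# Example of the 2-d mask -> additive-bias conversion used above (illustrative
# values only):
#   mask = torch.tensor([[1, 1, 0]])       # (batch, key_tokens), 1 = keep
#   bias = (1 - mask.float()) * -10000.0   # 0.0 where kept, -10000.0 where discarded
#   bias = bias.unsqueeze(1)               # (batch, 1, key_tokens)
# The singleton query dimension lets the bias broadcast over attention scores
# of shape (batch, heads, query_tokens, key_tokens).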
from typing import Dict, Type
import torch
import torch.nn as nn
from xfuser.logger import init_logger
from xfuser.model_executor.models.transformers.base_transformer import (
xFuserTransformerBaseWrapper,
)
logger = init_logger(__name__)
class xFuserTransformerWrappersRegister:
_XFUSER_TRANSFORMER_MAPPING: Dict[
Type[nn.Module], Type[xFuserTransformerBaseWrapper]
] = {}
@classmethod
def register(cls, origin_transformer_class: Type[nn.Module]):
def decorator(xfuser_transformer_class: Type[nn.Module]):
if not issubclass(
xfuser_transformer_class, xFuserTransformerBaseWrapper
):
raise ValueError(
f"{xfuser_transformer_class.__class__.__name__} is not "
f"a subclass of xFuserTransformerBaseWrapper"
)
cls._XFUSER_TRANSFORMER_MAPPING[origin_transformer_class] = (
xfuser_transformer_class
)
return xfuser_transformer_class
return decorator
@classmethod
def get_wrapper(cls, transformer: nn.Module) -> xFuserTransformerBaseWrapper:
candidate = None
candidate_origin = None
for (
origin_transformer_class,
wrapper_class,
) in cls._XFUSER_TRANSFORMER_MAPPING.items():
if isinstance(transformer, origin_transformer_class):
if (
candidate is None
or origin_transformer_class == transformer.__class__
or issubclass(origin_transformer_class, candidate_origin)
):
candidate_origin = origin_transformer_class
candidate = wrapper_class
if candidate is None:
raise ValueError(
f"Transformer class {transformer.__class__.__name__} "
f"is not supported by xFuser"
)
else:
return candidate
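
# Registration pattern (hypothetical example, not a real xfuser wrapper):
# decorating a class adds it to _XFUSER_TRANSFORMER_MAPPING, and get_wrapper
# later returns the wrapper registered for the most specific base class of the
# transformer instance.
#
#   @xFuserTransformerWrappersRegister.register(SomeTransformer2DModel)
#   class xFuserSomeTransformer2DWrapper(xFuserTransformerBaseWrapper):
#       def forward(self, *args, **kwargs):
#           ...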
from typing import Optional, Dict, Any, Union
import torch
import torch.distributed
import torch.nn as nn
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
from diffusers.utils import (
is_torch_version,
scale_lora_layers,
USE_PEFT_BACKEND,
unscale_lora_layers,
)
from xfuser.core.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
is_pipeline_first_stage,
is_pipeline_last_stage,
)
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.logger import init_logger
from xfuser.model_executor.models.transformers.register import (
xFuserTransformerWrappersRegister,
)
from xfuser.model_executor.models.transformers.base_transformer import (
xFuserTransformerBaseWrapper,
)
logger = init_logger(__name__)
from diffusers.models.attention import FeedForward
@xFuserTransformerWrappersRegister.register(FluxTransformer2DModel)
class xFuserFluxTransformer2DWrapper(xFuserTransformerBaseWrapper):
def __init__(
self,
transformer: FluxTransformer2DModel,
):
super().__init__(
transformer=transformer,
submodule_classes_to_wrap=(
[FeedForward] if get_tensor_model_parallel_world_size() > 1 else []
),
submodule_name_to_wrap=["attn"],
transformer_blocks_name=["transformer_blocks", "single_transformer_blocks"],
)
self.encoder_hidden_states_cache = [
None for _ in range(len(self.transformer_blocks))
]
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if (
joint_attention_kwargs is not None
and joint_attention_kwargs.get("scale", None) is not None
):
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
if is_pipeline_first_stage():
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
if is_pipeline_first_stage():
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor."
            )
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor."
            )
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
encoder_hidden_states, hidden_states = (
torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
# if controlnet_block_samples is not None:
# interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
# interval_control = int(np.ceil(interval_control))
# hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
# if self.stage_info.after_flags["transformer_blocks"]:
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
# if controlnet_single_block_samples is not None:
# interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
# interval_control = int(np.ceil(interval_control))
# hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
# hidden_states[:, encoder_hidden_states.shape[1] :, ...]
# + controlnet_single_block_samples[index_block // interval_control]
# )
encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
if self.stage_info.after_flags["single_transformer_blocks"]:
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states), None
else:
output = hidden_states, encoder_hidden_states
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
if __name__ == "__main__":
    # print the module structure of FluxTransformer2DModel
    model = FluxTransformer2DModel()
    print(model)
from typing import Optional, Dict, Any, Union
import torch
import torch.distributed
import torch.nn as nn
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
from diffusers.utils import (
is_torch_version,
scale_lora_layers,
USE_PEFT_BACKEND,
unscale_lora_layers,
)
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.logger import init_logger
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper
from xfuser.core.distributed import is_pipeline_first_stage, is_pipeline_last_stage
from .register import xFuserTransformerWrappersRegister
from .base_transformer import xFuserTransformerBaseWrapper
logger = init_logger(__name__)
@xFuserTransformerWrappersRegister.register(SD3Transformer2DModel)
class xFuserSD3Transformer2DWrapper(xFuserTransformerBaseWrapper):
def __init__(
self,
transformer: SD3Transformer2DModel,
):
super().__init__(
transformer=transformer,
submodule_classes_to_wrap=[nn.Conv2d, PatchEmbed],
submodule_name_to_wrap=["attn"],
)
self.encoder_hidden_states_cache = [
None for _ in range(len(self.transformer_blocks))
]
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
elif joint_attention_kwargs and "scale" in joint_attention_kwargs:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
# * get height & width from runtime state
height, width = self._get_patch_height_width()
# * only pp rank 0 needs pos_embed (patchify)
if is_pipeline_first_stage():
hidden_states = self.pos_embed(
hidden_states
) # takes care of adding positional embeddings too.
#! ORIGIN:
# height, width = hidden_states.shape[-2:]
# hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
temb = self.time_text_embed(timestep, pooled_projections)
#! ---------------------------------------- ADD BELOW ----------------------------------------
if is_pipeline_first_stage():
#! ---------------------------------------- ADD ABOVE ----------------------------------------
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
**ckpt_kwargs,
)
else:
if (
get_runtime_state().patch_mode
and get_runtime_state().pipeline_patch_idx == 0
):
self.encoder_hidden_states_cache[i] = encoder_hidden_states
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
)
elif get_runtime_state().patch_mode:
_, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=self.encoder_hidden_states_cache[i],
temb=temb,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
)
# * only the last pp rank needs unpatchify
#! ---------------------------------------- ADD BELOW ----------------------------------------
if is_pipeline_last_stage():
#! ---------------------------------------- ADD ABOVE ----------------------------------------
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# unpatchify
patch_size = self.config.patch_size
hidden_states = hidden_states.reshape(
shape=(
hidden_states.shape[0],
height,
width,
patch_size,
patch_size,
self.out_channels,
)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = (
hidden_states.reshape(
shape=(
hidden_states.shape[0],
self.out_channels,
height * patch_size,
width * patch_size,
)
),
None,
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
#! ---------------------------------------- ADD BELOW ----------------------------------------
else:
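# Intermediate pipeline-parallel stages skip unpatchify and instead hand both activations
# (hidden_states, encoder_hidden_states) to the next stage.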
output = hidden_states, encoder_hidden_states
#! ---------------------------------------- ADD ABOVE ----------------------------------------
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
from .base_pipeline import xFuserPipelineBaseWrapper
from .pipeline_pixart_alpha import xFuserPixArtAlphaPipeline
from .pipeline_pixart_sigma import xFuserPixArtSigmaPipeline
from .pipeline_stable_diffusion_3 import xFuserStableDiffusion3Pipeline
from .pipeline_flux import xFuserFluxPipeline
from .pipeline_latte import xFuserLattePipeline
from .pipeline_cogvideox import xFuserCogVideoXPipeline
from .pipeline_hunyuandit import xFuserHunyuanDiTPipeline
__all__ = [
"xFuserPipelineBaseWrapper",
"xFuserPixArtAlphaPipeline",
"xFuserPixArtSigmaPipeline",
"xFuserStableDiffusion3Pipeline",
"xFuserFluxPipeline",
"xFuserLattePipeline",
"xFuserHunyuanDiTPipeline",
"xFuserCogVideoXPipeline",
]
\ No newline at end of file
from abc import ABCMeta, abstractmethod
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed
import torch.nn as nn
from diffusers import DiffusionPipeline
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from distvae.modules.adapters.vae.decoder_adapters import DecoderAdapter
from xfuser.config.config import (
EngineConfig,
InputConfig,
)
from xfuser.core.distributed.parallel_state import get_tensor_model_parallel_world_size
from xfuser.logger import init_logger
from xfuser.core.distributed import (
get_data_parallel_world_size,
get_sequence_parallel_world_size,
get_pipeline_parallel_world_size,
get_classifier_free_guidance_world_size,
get_classifier_free_guidance_rank,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_pp_group,
get_world_group,
get_runtime_state,
initialize_runtime_state,
is_dp_last_group,
)
from xfuser.core.fast_attention import (
get_fast_attn_enable,
initialize_fast_attn_state,
fast_attention_compression,
)
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper
from xfuser.envs import PACKAGES_CHECKER
PACKAGES_CHECKER.check_diffusers_version()
from xfuser.model_executor.schedulers import *
from xfuser.model_executor.models.transformers import *
from xfuser.model_executor.layers.attention_processor import *
try:
import os
from onediff.infer_compiler import compile as od_compile
HAS_OF = True
os.environ["NEXFORT_FUSE_TIMESTEP_EMBEDDING"] = "0"
os.environ["NEXFORT_FX_FORCE_TRITON_SDPA"] = "1"
except ImportError:
HAS_OF = False
logger = init_logger(__name__)
class xFuserPipelineBaseWrapper(xFuserBaseWrapper, metaclass=ABCMeta):
def __init__(
self,
pipeline: DiffusionPipeline,
engine_config: EngineConfig,
):
self.module: DiffusionPipeline
self._init_runtime_state(pipeline=pipeline, engine_config=engine_config)
self._init_fast_attn_state(pipeline=pipeline, engine_config=engine_config)
# backbone
transformer = getattr(pipeline, "transformer", None)
unet = getattr(pipeline, "unet", None)
# vae
vae = getattr(pipeline, "vae", None)
# scheduler
scheduler = getattr(pipeline, "scheduler", None)
if transformer is not None:
pipeline.transformer = self._convert_transformer_backbone(
transformer,
enable_torch_compile=engine_config.runtime_config.use_torch_compile,
enable_onediff=engine_config.runtime_config.use_onediff,
)
elif unet is not None:
pipeline.unet = self._convert_unet_backbone(unet)
if scheduler is not None:
pipeline.scheduler = self._convert_scheduler(scheduler)
if vae is not None and engine_config.runtime_config.use_parallel_vae and not self.use_naive_forward():
pipeline.vae = self._convert_vae(vae)
super().__init__(module=pipeline)
def reset_activation_cache(self):
if hasattr(self.module, "transformer") and hasattr(
self.module.transformer, "reset_activation_cache"
):
self.module.transformer.reset_activation_cache()
if hasattr(self.module, "unet") and hasattr(
self.module.unet, "reset_activation_cache"
):
self.module.unet.reset_activation_cache()
if hasattr(self.module, "vae") and hasattr(
self.module.vae, "reset_activation_cache"
):
self.module.vae.reset_activation_cache()
if hasattr(self.module, "scheduler") and hasattr(
self.module.scheduler, "reset_activation_cache"
):
self.module.scheduler.reset_activation_cache()
def to(self, *args, **kwargs):
self.module = self.module.to(*args, **kwargs)
return self
@staticmethod
def enable_fast_attn(func):
@wraps(func)
def fast_attn_fn(self, *args, **kwargs):
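# When DiTFastAttn is enabled, reset every attention layer's step counter and caches
# before and after the wrapped call so each sampling run starts from a clean state.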
if get_fast_attn_enable():
for block in self.module.transformer.transformer_blocks:
for layer in block.children():
if isinstance(layer, xFuserAttentionBaseWrapper):
layer.stepi = 0
layer.cached_residual = None
layer.cached_output = None
out = func(self, *args, **kwargs)
for block in self.module.transformer.transformer_blocks:
for layer in block.children():
if isinstance(layer, xFuserAttentionBaseWrapper):
layer.stepi = 0
layer.cached_residual = None
layer.cached_output = None
return out
else:
return func(self, *args, **kwargs)
return fast_attn_fn
@staticmethod
def enable_data_parallel(func):
@wraps(func)
def data_parallel_fn(self, *args, **kwargs):
prompt = kwargs.get("prompt", None)
negative_prompt = kwargs.get("negative_prompt", "")
# dp_degree <= batch_size
batch_size = len(prompt) if isinstance(prompt, list) else 1
if batch_size > 1:
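# Split the prompt batch into contiguous chunks of ceil(batch_size / dp_degree) and let
# each data-parallel group work on its own chunk.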
dp_degree = get_runtime_state().parallel_config.dp_degree
dp_group_rank = get_world_group().rank // (
get_world_group().world_size // get_data_parallel_world_size()
)
dp_group_batch_size = (batch_size + dp_degree - 1) // dp_degree
start_batch_idx = dp_group_rank * dp_group_batch_size
end_batch_idx = min(
(dp_group_rank + 1) * dp_group_batch_size, batch_size
)
prompt = prompt[start_batch_idx:end_batch_idx]
if isinstance(negative_prompt, List):
negative_prompt = negative_prompt[start_batch_idx:end_batch_idx]
kwargs["prompt"] = prompt
if "negative_prompt" in kwargs:
kwargs["negative_prompt"] = negative_prompt
return func(self, *args, **kwargs)
return data_parallel_fn
def use_naive_forward(self):
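# True when no parallelism (pipeline / CFG / sequence / tensor) and no DiTFastAttn is enabled,
# i.e. the original diffusers forward can be used as-is.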
return (
get_pipeline_parallel_world_size() == 1
and get_classifier_free_guidance_world_size() == 1
and get_sequence_parallel_world_size() == 1
and get_tensor_model_parallel_world_size() == 1
and not get_fast_attn_enable()
)
@staticmethod
def check_to_use_naive_forward(func):
@wraps(func)
def check_naive_forward_fn(self, *args, **kwargs):
if self.use_naive_forward():
return self.module(*args, **kwargs)
else:
return func(self, *args, **kwargs)
return check_naive_forward_fn
@staticmethod
def check_model_parallel_state(
cfg_parallel_available: bool = True,
sequence_parallel_available: bool = True,
pipefusion_parallel_available: bool = True,
):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if (
not cfg_parallel_available
and get_runtime_state().parallel_config.cfg_degree > 1
):
raise RuntimeError("CFG parallelism is not supported by the model")
if (
not sequence_parallel_available
and get_runtime_state().parallel_config.sp_degree > 1
):
raise RuntimeError(
"Sequence parallelism is not supported by the model"
)
if (
not pipefusion_parallel_available
and get_runtime_state().parallel_config.pp_degree > 1
):
raise RuntimeError(
"Pipefusion parallelism is not supported by the model"
)
return func(*args, **kwargs)
return wrapper
return decorator
def forward(self):
pass
def prepare_run(
self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1
):
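# Warm-up run: temporarily set warmup_steps to sync_steps and generate with a dummy prompt so
# that runtime state and any lazy initialization are exercised before real inference.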
if get_fast_attn_enable():
# set compression methods for DiTFastAttn
fast_attention_compression(self)
prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else ""
warmup_steps = get_runtime_state().runtime_config.warmup_steps
get_runtime_state().runtime_config.warmup_steps = sync_steps
self.__call__(
height=input_config.height,
width=input_config.width,
prompt=prompt,
use_resolution_binning=input_config.use_resolution_binning,
num_inference_steps=steps,
generator=torch.Generator(device="cuda").manual_seed(42),
output_type=input_config.output_type,
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps
def latte_prepare_run(
self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1
):
prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else ""
warmup_steps = get_runtime_state().runtime_config.warmup_steps
get_runtime_state().runtime_config.warmup_steps = sync_steps
self.__call__(
height=input_config.height,
width=input_config.width,
prompt=prompt,
# use_resolution_binning=input_config.use_resolution_binning,
num_inference_steps=steps,
output_type="latent",
generator=torch.Generator(device="cuda").manual_seed(42),
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps
def _init_runtime_state(
self, pipeline: DiffusionPipeline, engine_config: EngineConfig
):
initialize_runtime_state(pipeline=pipeline, engine_config=engine_config)
def _init_fast_attn_state(
self, pipeline: DiffusionPipeline, engine_config: EngineConfig
):
initialize_fast_attn_state(pipeline=pipeline, single_config=engine_config.fast_attn_config)
def _convert_transformer_backbone(
self, transformer: nn.Module, enable_torch_compile: bool, enable_onediff: bool
):
if (
get_pipeline_parallel_world_size() == 1
and get_sequence_parallel_world_size() == 1
and get_classifier_free_guidance_world_size() == 1
and get_tensor_model_parallel_world_size() == 1
and not get_fast_attn_enable()
):
logger.info(
"Transformer backbone found, but model parallelism is not enabled, "
"use naive model"
)
else:
logger.info("Transformer backbone found, paralleling transformer...")
wrapper = xFuserTransformerWrappersRegister.get_wrapper(transformer)
transformer = wrapper(transformer)
if enable_torch_compile and enable_onediff:
logger.warning(
f"apply --use_torch_compile and --use_onediff togather. we use torch compile only"
)
if enable_torch_compile or enable_onediff:
if getattr(transformer, "forward") is not None:
if enable_torch_compile:
optimized_transformer_forward = torch.compile(
getattr(transformer, "forward")
)
elif enable_onediff:
# O3: +fp16 reduction
if not HAS_OF:
raise RuntimeError(
"install onediff and nexfort to --use_onediff"
)
options = {"mode": "O3"} # mode can be O2 or O3
optimized_transformer_forward = od_compile(
getattr(transformer, "forward"),
backend="nexfort",
options=options,
)
setattr(transformer, "forward", optimized_transformer_forward)
else:
raise AttributeError(
f"Transformer backbone type: {transformer.__class__.__name__} has no attribute 'forward'"
)
return transformer
def _convert_unet_backbone(
self,
unet: nn.Module,
):
logger.info("UNet Backbone found")
raise NotImplementedError("UNet parallelisation is not supported yet")
def _convert_scheduler(
self,
scheduler: nn.Module,
):
logger.info("Scheduler found, paralleling scheduler...")
wrapper = xFuserSchedulerWrappersRegister.get_wrapper(scheduler)
scheduler = wrapper(scheduler)
return scheduler
def _convert_vae(
self,
vae: AutoencoderKL,
):
logger.info("VAE found, paralleling vae...")
vae.decoder = DecoderAdapter(vae.decoder)
return vae
@abstractmethod
def __call__(self):
pass
def _init_sync_pipeline(self, latents: torch.Tensor):
get_runtime_state().set_patched_mode(patch_mode=False)
latents_list = [
latents[:, :, start_idx:end_idx, :]
for start_idx, end_idx in get_runtime_state().pp_patches_start_end_idx_global
]
latents = torch.cat(latents_list, dim=-2)
return latents
def _init_video_sync_pipeline(self, latents: torch.Tensor):
get_runtime_state().set_patched_mode(patch_mode=False)
latents_list = [
latents[:, :, :, start_idx:end_idx, :]
for start_idx, end_idx in get_runtime_state().pp_patches_start_end_idx_global
]
latents = torch.cat(latents_list, dim=-2)
return latents
def _init_async_pipeline(
self,
num_timesteps: int,
latents: torch.Tensor,
num_pipeline_warmup_steps: int,
):
get_runtime_state().set_patched_mode(patch_mode=True)
if is_pipeline_first_stage():
# get latents computed in warmup stage
# ignore latents after the last timestep
latents = (
get_pp_group().pipeline_recv()
if num_pipeline_warmup_steps > 0
else latents
)
patch_latents = list(
latents.split(get_runtime_state().pp_patches_height, dim=2)
)
elif is_pipeline_last_stage():
patch_latents = list(
latents.split(get_runtime_state().pp_patches_height, dim=2)
)
else:
patch_latents = [
None for _ in range(get_runtime_state().num_pipeline_patch)
]
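# The first pipeline stage already holds its first-iteration input (the warmup latents),
# so it pre-registers one fewer timestep of receives than the other stages.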
recv_timesteps = (
num_timesteps - 1 if is_pipeline_first_stage() else num_timesteps
)
for _ in range(recv_timesteps):
for patch_idx in range(get_runtime_state().num_pipeline_patch):
get_pp_group().add_pipeline_recv_task(patch_idx)
return patch_latents
def _process_cfg_split_batch(
self,
negative_embeds: torch.Tensor,
embeds: torch.Tensor,
negative_embdes_mask: torch.Tensor = None,
embeds_mask: torch.Tensor = None,
):
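# Classifier-free guidance batch handling: with cfg_degree == 1 the negative and positive
# embeddings are concatenated into a single batch; with cfg_degree == 2 rank 0 keeps the
# negative embeddings and rank 1 the positive ones, halving the per-rank batch.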
if get_classifier_free_guidance_world_size() == 1:
embeds = torch.cat([negative_embeds, embeds], dim=0)
elif get_classifier_free_guidance_rank() == 0:
embeds = negative_embeds
elif get_classifier_free_guidance_rank() == 1:
embeds = embeds
else:
raise ValueError("Invalid classifier free guidance rank")
if negative_embdes_mask is None:
return embeds
if get_classifier_free_guidance_world_size() == 1:
embeds_mask = torch.cat([negative_embdes_mask, embeds_mask], dim=0)
elif get_classifier_free_guidance_rank() == 0:
embeds_mask = negative_embdes_mask
elif get_classifier_free_guidance_rank() == 1:
embeds_mask = embeds_mask
else:
raise ValueError("Invalid classifier free guidance rank")
return embeds, embeds_mask
def is_dp_last_group(self):
"""Return True if in the last data parallel group, False otherwise.
Also include parallel vae situation.
"""
if get_runtime_state().runtime_config.use_parallel_vae and not self.use_naive_forward():
return get_world_group().rank == 0
else:
return is_dp_last_group()
def gather_broadcast_latents(self, latents:torch.Tensor):
"""gather latents from dp last group and broacast final latents
"""
# ---------gather latents from dp last group-----------
rank = get_world_group().rank
device = f"cuda:{rank}"
# all gather dp last group rank list
dp_rank_list = [torch.zeros(1, dtype=int, device=device) for _ in range(get_world_group().world_size)]
if is_dp_last_group():
gather_rank = int(rank)
else:
gather_rank = -1
torch.distributed.all_gather(dp_rank_list, torch.tensor([gather_rank],dtype=int,device=device))
dp_rank_list = [int(dp_rank[0]) for dp_rank in dp_rank_list if int(dp_rank[0])!=-1]
dp_last_group = torch.distributed.new_group(dp_rank_list)
# gather latents from dp last group
if rank == dp_rank_list[-1]:
latents_list = [torch.zeros_like(latents) for _ in dp_rank_list]
else:
latents_list = None
if rank in dp_rank_list:
torch.distributed.gather(latents, latents_list, dst=dp_rank_list[-1], group=dp_last_group)
if rank == dp_rank_list[-1]:
latents = torch.cat(latents_list,dim=0)
# ------broadcast latents to all nodes---------
src = dp_rank_list[-1]
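# Ranks outside the last DP group do not know the gathered shape, so broadcast the number
# of dimensions and the shape first, then the latent tensor itself.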
latents_shape_len = torch.zeros(1,dtype=torch.int,device=device)
# broadcast latents shape len
if rank == src:
latents_shape_len[0] = len(latents.shape)
get_world_group().broadcast(latents_shape_len,src=src)
# broadcast latents shape
if rank == src:
input_shape = torch.tensor(latents.shape,dtype=torch.int,device=device)
else:
input_shape = torch.zeros(latents_shape_len[0],dtype=torch.int,device=device)
get_world_group().broadcast(input_shape,src=src)
# broadcast latents
if rank != src:
dtype = get_runtime_state().runtime_config.dtype
latents = torch.zeros(torch.Size(input_shape),dtype=dtype,device=device)
get_world_group().broadcast(latents,src=src)
return latents
import os
from typing import Any, List, Tuple, Callable, Optional, Union, Dict
import torch
import torch.distributed
from diffusers import CogVideoXPipeline
from diffusers.pipelines.cogvideo.pipeline_cogvideox import (
CogVideoXPipelineOutput,
retrieve_timesteps,
)
from diffusers.schedulers import CogVideoXDPMScheduler
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
import math
from xfuser.config import EngineConfig
from xfuser.core.distributed import (
get_pipeline_parallel_world_size,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_classifier_free_guidance_world_size,
get_cfg_group,
get_sp_group,
get_runtime_state,
is_dp_last_group,
)
from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
@xFuserPipelineWrapperRegister.register(CogVideoXPipeline)
class xFuserCogVideoXPipeline(xFuserPipelineBaseWrapper):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
engine_config: EngineConfig,
**kwargs,
):
pipeline = CogVideoXPipeline.from_pretrained(
pretrained_model_name_or_path, **kwargs
)
return cls(pipeline, engine_config)
@torch.no_grad()
@xFuserPipelineBaseWrapper.enable_data_parallel
@xFuserPipelineBaseWrapper.check_to_use_naive_forward
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
) -> Union[CogVideoXPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
The width in pixels of the generated image. This is set to 720 by default for the best results.
num_frames (`int`, defaults to `48`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
needs to be satisfied is that of divisibility mentioned above.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to `6`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] instead
of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `226`):
Maximum sequence length in encoded prompt. Must be consistent with
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
Examples:
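A minimal illustrative sketch (untested; assumes an `EngineConfig` named `engine_config`
has already been built elsewhere and the distributed environment is initialized):

>>> import torch
>>> pipe = xFuserCogVideoXPipeline.from_pretrained(
...     "THUDM/CogVideoX-2b", engine_config=engine_config, torch_dtype=torch.bfloat16
... ).to("cuda")
>>> frames = pipe(prompt="a panda playing guitar", num_inference_steps=50).frames[0]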
Returns:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
num_frames = num_frames or self.transformer.config.sample_frames
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
get_runtime_state().set_video_input_parameters(
height=height,
width=width,
num_frames=num_frames,
batch_size=batch_size,
num_inference_steps=num_inference_steps,
split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
)
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
prompt_embeds = self._process_cfg_split_batch(
negative_prompt_embeds, prompt_embeds
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
patch_size_t = self.transformer.config.patch_size_t
additional_frames = 0
if patch_size_t is not None and latent_frames % patch_size_t != 0:
additional_frames = patch_size_t - latent_frames % patch_size_t
num_frames += additional_frames * self.vae_scale_factor_temporal
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(
height, width, latents.size(1), device
)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
p_t = self.transformer.config.patch_size_t or 1
latents, prompt_embeds, image_rotary_emb = self._init_sync_pipeline(
latents, prompt_embeds, image_rotary_emb,
(latents.size(1) + p_t - 1) // p_t
)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if do_classifier_free_guidance:
latent_model_input = torch.cat(
[latents] * (2 // get_classifier_free_guidance_world_size())
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
if get_classifier_free_guidance_world_size() == 1:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
elif get_classifier_free_guidance_world_size() == 2:
noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(
noise_pred, separate_tensors=True
)
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler.module, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if get_sequence_parallel_world_size() > 1:
latents = get_sp_group().all_gather(latents, dim=-2)
if is_dp_last_group():
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
else:
video = [None for _ in range(batch_size)]
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CogVideoXPipelineOutput(frames=video)
def _init_sync_pipeline(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
latents_frames: Optional[int] = None,
):
latents = super()._init_video_sync_pipeline(latents)
if get_runtime_state().split_text_embed_in_sp:
if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0:
prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
else:
get_runtime_state().split_text_embed_in_sp = False
if image_rotary_emb is not None:
assert latents_frames is not None
d = image_rotary_emb[0].shape[-1]
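# Slice the cos/sin rotary tables frame by frame to the token ranges owned by this rank's
# pipeline patches, so they line up with the sliced latents.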
image_rotary_emb = (
torch.cat(
[
image_rotary_emb[0]
.reshape(latents_frames, -1, d)[
:, start_token_idx:end_token_idx
]
.reshape(-1, d)
for start_token_idx, end_token_idx in get_runtime_state().pp_patches_token_start_end_idx_global
],
dim=0,
),
torch.cat(
[
image_rotary_emb[1]
.reshape(latents_frames, -1, d)[
:, start_token_idx:end_token_idx
]
.reshape(-1, d)
for start_token_idx, end_token_idx in get_runtime_state().pp_patches_token_start_end_idx_global
],
dim=0,
),
)
return latents, prompt_embeds, image_rotary_emb
@property
def interrupt(self):
return self._interrupt
@property
def guidance_scale(self):
return self._guidance_scale
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, List, Tuple, Callable, Optional, Union
import numpy as np
import torch
import torch.distributed
from diffusers import FluxPipeline
from diffusers.utils import is_torch_xla_available
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps, calculate_shift
from xfuser.config import EngineConfig, InputConfig
from xfuser.core.distributed import (
get_pipeline_parallel_world_size,
get_runtime_state,
get_pp_group,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_sp_group,
is_pipeline_first_stage,
is_pipeline_last_stage,
is_dp_last_group,
)
from .base_pipeline import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
@xFuserPipelineWrapperRegister.register(FluxPipeline)
class xFuserFluxPipeline(xFuserPipelineBaseWrapper):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
engine_config: EngineConfig,
**kwargs,
):
pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(pipeline, engine_config)
def prepare_run(
self,
input_config: InputConfig,
steps: int = 3,
sync_steps: int = 1,
):
prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else ""
warmup_steps = get_runtime_state().runtime_config.warmup_steps
get_runtime_state().runtime_config.warmup_steps = sync_steps
self.__call__(
height=input_config.height,
width=input_config.width,
prompt=prompt,
num_inference_steps=steps,
max_sequence_length=input_config.max_sequence_length,
generator=torch.Generator(device="cuda").manual_seed(42),
output_type=input_config.output_type,
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@xFuserPipelineBaseWrapper.check_model_parallel_state(cfg_parallel_available=False)
@xFuserPipelineBaseWrapper.enable_data_parallel
@xFuserPipelineBaseWrapper.check_to_use_naive_forward
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
will be used instead.
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
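A minimal illustrative sketch (untested; assumes `engine_config` was built elsewhere and
the distributed environment is initialized):

>>> import torch
>>> pipe = xFuserFluxPipeline.from_pretrained(
...     "black-forest-labs/FLUX.1-schnell", engine_config=engine_config, torch_dtype=torch.bfloat16
... ).to("cuda")
>>> image = pipe(prompt="a cat wearing sunglasses", num_inference_steps=4).images[0]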
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
#! ---------------------------------------- ADDED BELOW ----------------------------------------
# * set runtime state input parameters
get_runtime_state().set_input_parameters(
height=height,
width=width,
batch_size=batch_size,
num_inference_steps=num_inference_steps,
max_condition_sequence_length=max_sequence_length,
split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
)
#! ---------------------------------------- ADDED ABOVE ----------------------------------------
lora_scale = (
self.joint_attention_kwargs.get("scale", None)
if self.joint_attention_kwargs is not None
else None
)
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
num_warmup_steps = max(
len(timesteps) - num_inference_steps * self.scheduler.order, 0
)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full(
[1], guidance_scale, device=device, dtype=torch.float32
)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
if (
get_pipeline_parallel_world_size() > 1
and len(timesteps) > num_pipeline_warmup_steps
):
# raise RuntimeError("Async pipeline not supported in flux")
latents = self._sync_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
text_ids=text_ids,
latent_image_ids=latent_image_ids,
guidance=guidance,
timesteps=timesteps[:num_pipeline_warmup_steps],
num_warmup_steps=num_warmup_steps,
progress_bar=progress_bar,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
latents = self._async_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
text_ids=text_ids,
latent_image_ids=latent_image_ids,
guidance=guidance,
timesteps=timesteps[num_pipeline_warmup_steps:],
num_warmup_steps=num_warmup_steps,
progress_bar=progress_bar,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
else:
latents = self._sync_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
text_ids=text_ids,
latent_image_ids=latent_image_ids,
guidance=guidance,
timesteps=timesteps,
num_warmup_steps=num_warmup_steps,
progress_bar=progress_bar,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
sync_only=True,
)
def vae_decode(latents):
latents = self._unpack_latents(
latents, height, width, self.vae_scale_factor
)
latents = (
latents / self.vae.config.scaling_factor
) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
return image
if not output_type == "latent":
if get_runtime_state().runtime_config.use_parallel_vae:
latents = self.gather_broadcast_latents(latents)
image = vae_decode(latents)
else:
if is_dp_last_group():
image = vae_decode(latents)
if self.is_dp_last_group():
if output_type == "latent":
image = latents
else:
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
else:
return None
def _init_sync_pipeline(
self, latents: torch.Tensor, latent_image_ids: torch.Tensor,
prompt_embeds: torch.Tensor, text_ids: torch.Tensor
):
get_runtime_state().set_patched_mode(patch_mode=False)
latents_list = [
latents[:, start_idx:end_idx, :]
for start_idx, end_idx in get_runtime_state().pp_patches_token_start_end_idx_global
]
latents = torch.cat(latents_list, dim=-2)
latent_image_ids_list = [
latent_image_ids[start_idx:end_idx]
for start_idx, end_idx in get_runtime_state().pp_patches_token_start_end_idx_global
]
latent_image_ids = torch.cat(latent_image_ids_list, dim=-2)
if get_runtime_state().split_text_embed_in_sp:
if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0:
prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
else:
get_runtime_state().split_text_embed_in_sp = False
if get_runtime_state().split_text_embed_in_sp:
if text_ids.shape[-2] % get_sequence_parallel_world_size() == 0:
text_ids = torch.chunk(text_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
else:
get_runtime_state().split_text_embed_in_sp = False
return latents, latent_image_ids, prompt_embeds, text_ids
# synchronized compute the whole feature map in each pp stage
def _sync_pipeline(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
pooled_prompt_embeds: torch.Tensor,
text_ids: torch.Tensor,
latent_image_ids: torch.Tensor,
guidance,
timesteps: List[int],
num_warmup_steps: int,
progress_bar,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
sync_only: bool = False,
):
latents, latent_image_ids, prompt_embeds, text_ids = self._init_sync_pipeline(latents, latent_image_ids, prompt_embeds, text_ids)
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if is_pipeline_last_stage():
last_timestep_latents = latents
# when there is only one pp stage, no need to recv
if get_pipeline_parallel_world_size() == 1:
pass
# all ranks should recv the latent from the previous rank except
# the first rank in the first pipeline forward which should use
# the input latent
elif is_pipeline_first_stage() and i == 0:
pass
else:
latents = get_pp_group().pipeline_recv()
if not is_pipeline_first_stage():
encoder_hidden_state = get_pp_group().pipeline_recv(
0, "encoder_hidden_state"
)
# # handle guidance
# if self.transformer.config.guidance_embeds:
# guidance = torch.tensor([guidance_scale], device=self._execution_device)
# guidance = guidance.expand(latents.shape[0])
# else:
# guidance = None
latents, encoder_hidden_state = self._backbone_forward(
latents=latents,
encoder_hidden_states=(
prompt_embeds if is_pipeline_first_stage() else encoder_hidden_state
),
pooled_prompt_embeds=pooled_prompt_embeds,
text_ids=text_ids,
latent_image_ids=latent_image_ids,
guidance=guidance,
t=t,
)
if is_pipeline_last_stage():
latents_dtype = latents.dtype
latents = self._scheduler_step(latents, last_timestep_latents, t)
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if sync_only and is_pipeline_last_stage() and i == len(timesteps) - 1:
pass
elif get_pipeline_parallel_world_size() > 1:
get_pp_group().pipeline_send(latents)
if not is_pipeline_last_stage():
get_pp_group().pipeline_send(
encoder_hidden_state, name="encoder_hidden_state"
)
if (
sync_only
and get_sequence_parallel_world_size() > 1
and is_pipeline_last_stage()
):
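# Gather the sequence-parallel shards and re-interleave them patch by patch so the
# latents are restored to their original token order.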
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
:,
get_runtime_state()
.pp_patches_token_start_idx_local[pp_patch_idx] : get_runtime_state()
.pp_patches_token_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
latents = torch.cat(latents_list, dim=-2)
return latents
def _async_pipeline(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
pooled_prompt_embeds: torch.Tensor,
text_ids: torch.Tensor,
latent_image_ids: torch.Tensor,
guidance,
timesteps: List[int],
num_warmup_steps: int,
progress_bar,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
if len(timesteps) == 0:
return latents
num_pipeline_patch = get_runtime_state().num_pipeline_patch
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
patch_latents, patch_latent_image_ids = self._init_async_pipeline(
num_timesteps=len(timesteps),
latents=latents,
num_pipeline_warmup_steps=num_pipeline_warmup_steps,
latent_image_ids=latent_image_ids,
)
last_patch_latents = (
[None for _ in range(num_pipeline_patch)]
if (is_pipeline_last_stage())
else None
)
first_async_recv = True
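# Pipelined patch loop: each stage works on one spatial patch while asynchronously
# receiving the next one, overlapping communication with computation across stages.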
for i, t in enumerate(timesteps):
if self.interrupt:
continue
for patch_idx in range(num_pipeline_patch):
if is_pipeline_last_stage():
last_patch_latents[patch_idx] = patch_latents[patch_idx]
if is_pipeline_first_stage() and i == 0:
pass
else:
if first_async_recv:
if not is_pipeline_first_stage() and patch_idx == 0:
get_pp_group().recv_next()
get_pp_group().recv_next()
first_async_recv = False
if not is_pipeline_first_stage() and patch_idx == 0:
last_encoder_hidden_states = (
get_pp_group().get_pipeline_recv_data(
idx=patch_idx, name="encoder_hidden_states"
)
)
patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data(
idx=patch_idx
)
patch_latents[patch_idx], next_encoder_hidden_states = (
self._backbone_forward(
latents=patch_latents[patch_idx],
encoder_hidden_states=(
prompt_embeds
if is_pipeline_first_stage()
else last_encoder_hidden_states
),
pooled_prompt_embeds=pooled_prompt_embeds,
text_ids=text_ids,
latent_image_ids=patch_latent_image_ids[patch_idx],
guidance=guidance,
t=t,
)
)
if is_pipeline_last_stage():
latents_dtype = patch_latents[patch_idx].dtype
patch_latents[patch_idx] = self._scheduler_step(
patch_latents[patch_idx],
last_patch_latents[patch_idx],
t,
)
if patch_latents[patch_idx].dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
patch_latents[patch_idx] = patch_latents[patch_idx].to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(
self, i, t, callback_kwargs
)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop(
"prompt_embeds", prompt_embeds
)
# Flux has no negative prompt embeddings, so there is nothing else to restore here.
if i != len(timesteps) - 1:
get_pp_group().pipeline_isend(
patch_latents[patch_idx], segment_idx=patch_idx
)
else:
if patch_idx == 0:
get_pp_group().pipeline_isend(
next_encoder_hidden_states, name="encoder_hidden_states"
)
get_pp_group().pipeline_isend(
patch_latents[patch_idx], segment_idx=patch_idx
)
if is_pipeline_first_stage() and i == 0:
pass
else:
if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
pass
elif is_pipeline_first_stage():
get_pp_group().recv_next()
else:
# recv encoder_hidden_state
if patch_idx == num_pipeline_patch - 1:
get_pp_group().recv_next()
# recv latents
get_pp_group().recv_next()
get_runtime_state().next_patch()
if i == len(timesteps) - 1 or (
(i + num_pipeline_warmup_steps + 1) > num_warmup_steps
and (i + num_pipeline_warmup_steps + 1) % self.scheduler.order == 0
):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
latents = None
if is_pipeline_last_stage():
latents = torch.cat(patch_latents, dim=-2)
if get_sequence_parallel_world_size() > 1:
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(
latents, separate_tensors=True
)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
...,
get_runtime_state()
.pp_patches_token_start_idx_local[
pp_patch_idx
] : get_runtime_state()
.pp_patches_token_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
latents = torch.cat(latents_list, dim=-2)
return latents
def _init_async_pipeline(
self,
num_timesteps: int,
latents: torch.Tensor,
num_pipeline_warmup_steps: int,
latent_image_ids: torch.Tensor,
):
get_runtime_state().set_patched_mode(patch_mode=True)
if is_pipeline_first_stage():
# get latents computed in warmup stage
# ignore latents after the last timestep
latents = (
get_pp_group().pipeline_recv()
if num_pipeline_warmup_steps > 0
else latents
)
patch_latents = list(
latents.split(get_runtime_state().pp_patches_token_num, dim=-2)
)
elif is_pipeline_last_stage():
patch_latents = list(
latents.split(get_runtime_state().pp_patches_token_num, dim=-2)
)
else:
patch_latents = [
None for _ in range(get_runtime_state().num_pipeline_patch)
]
patch_latent_image_ids = list(
latent_image_ids[start_idx:end_idx]
for start_idx, end_idx in get_runtime_state().pp_patches_token_start_end_idx_global
)
recv_timesteps = (
num_timesteps - 1 if is_pipeline_first_stage() else num_timesteps
)
if is_pipeline_first_stage():
for _ in range(recv_timesteps):
for patch_idx in range(get_runtime_state().num_pipeline_patch):
get_pp_group().add_pipeline_recv_task(patch_idx)
else:
for _ in range(recv_timesteps):
get_pp_group().add_pipeline_recv_task(0, "encoder_hidden_states")
for patch_idx in range(get_runtime_state().num_pipeline_patch):
get_pp_group().add_pipeline_recv_task(patch_idx)
return patch_latents, patch_latent_image_ids
def _backbone_forward(
self,
latents: torch.Tensor,
encoder_hidden_states: torch.Tensor,
pooled_prompt_embeds: torch.Tensor,
text_ids,
latent_image_ids,
guidance,
t: Union[float, torch.Tensor],
):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
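# The Flux transformer expects timesteps scaled into [0, 1], hence the division by 1000
# below (matching diffusers' FluxPipeline).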
noise_pred, encoder_hidden_states = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=encoder_hidden_states,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
return noise_pred, encoder_hidden_states
def _scheduler_step(
self,
noise_pred: torch.Tensor,
latents: torch.Tensor,
t: Union[float, torch.Tensor],
):
return self.scheduler.step(
noise_pred,
t,
latents,
return_dict=False,
)[0]
import os
from typing import Callable, Dict, List, Tuple, Optional, Union
import torch
import torch.distributed
from diffusers import HunyuanDiTPipeline
from diffusers.pipelines.hunyuandit.pipeline_hunyuandit import (
SUPPORTED_SHAPE,
map_to_standard_shapes,
get_resize_crop_region_for_grid,
rescale_noise_cfg,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.models.embeddings import get_2d_rotary_pos_embed
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from xfuser.config import EngineConfig
from xfuser.logger import init_logger
from xfuser.core.distributed import (
get_classifier_free_guidance_world_size,
get_pipeline_parallel_world_size,
get_runtime_state,
get_pipeline_parallel_rank,
get_cfg_group,
get_pp_group,
get_sequence_parallel_world_size,
get_sp_group,
is_dp_last_group,
is_pipeline_last_stage,
is_pipeline_first_stage,
get_world_group
)
from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
logger = init_logger(__name__)
@xFuserPipelineWrapperRegister.register(HunyuanDiTPipeline)
class xFuserHunyuanDiTPipeline(xFuserPipelineBaseWrapper):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
engine_config: EngineConfig,
**kwargs,
):
pipeline = HunyuanDiTPipeline.from_pretrained(
pretrained_model_name_or_path, **kwargs
)
return cls(pipeline, engine_config)
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@xFuserPipelineBaseWrapper.enable_data_parallel
@xFuserPipelineBaseWrapper.check_to_use_naive_forward
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
prompt_attention_mask_2: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[
Callable[[int, int, Dict], None],
PipelineCallback,
MultiPipelineCallbacks,
]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = (1024, 1024),
target_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
use_resolution_binning: bool = True,
**kwargs,
):
r"""
The call function to the pipeline for generation with HunyuanDiT.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 5.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A callback function or a list of callback functions to be called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
A list of tensor inputs that should be passed to the callback function. If not defined, all tensor
inputs will be passed.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
The original size of the image. Used to calculate the time ids.
target_size (`Tuple[int, int]`, *optional*):
The target size of the image. Used to calculate the time ids.
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
The top left coordinates of the crop. Used to calculate the time ids.
use_resolution_binning (`bool`, *optional*, defaults to `True`):
Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest
standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,
768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.
Examples:
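A minimal, illustrative usage sketch (not taken from this repository). It assumes the
distributed environment and an `EngineConfig` instance named `engine_config` have already
been set up elsewhere; the model id below is a placeholder.
```py
>>> import torch
>>> # imports, distributed setup and `engine_config` construction omitted (assumed)
>>> pipe = xFuserHunyuanDiTPipeline.from_pretrained(
...     "Tencent-Hunyuan/HunyuanDiT-Diffusers",  # placeholder model id
...     engine_config=engine_config,
...     torch_dtype=torch.float16,
... )
>>> out = pipe(prompt="an astronaut riding a horse", num_inference_steps=50)
>>> if out is not None:  # only the last data-parallel group returns images
...     out.images[0].save("hunyuan.png")
```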
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. default height and width
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
height = int((height // 16) * 16)
width = int((width // 16) * 16)
if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:
width, height = map_to_standard_shapes(width, height)
height = int(height)
width = int(width)
logger.warning(
f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}"
)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
#! ---------------------------------------- ADDED BELOW ----------------------------------------
# * set runtime state input parameters
get_runtime_state().set_input_parameters(
height=height,
width=width,
batch_size=batch_size,
num_inference_steps=num_inference_steps,
)
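# HunyuanDiT uses long skip connections between the first and second halves of its transformer
# blocks, so pipeline stages in the second half also receive skip activations from their
# mirrored stage (see pipeline_recv_skip below); reset the per-block receive buffers for them.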
if get_pipeline_parallel_rank() >= get_pipeline_parallel_world_size() // 2:
num_blocks_per_stage = len(self.transformer.blocks)
get_runtime_state()._reset_recv_skip_buffer(num_blocks_per_stage)
#! ---------------------------------------- ADDED ABOVE ----------------------------------------
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=77,
text_encoder_index=0,
)
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds_2,
negative_prompt_embeds=negative_prompt_embeds_2,
prompt_attention_mask=prompt_attention_mask_2,
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
max_sequence_length=256,
text_encoder_index=1,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. create image_rotary_emb, style embedding & time ids
grid_height = height // 8 // self.transformer.config.patch_size
grid_width = width // 8 // self.transformer.config.patch_size
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size
)
image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
)
style = torch.tensor([0], device=device)
target_size = target_size or (height, width)
add_time_ids = list(original_size + target_size + crops_coords_top_left)
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
# * dealing with cfg degree
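# With classifier-free-guidance parallelism (cfg world size 2) the negative and positive prompt
# batches live on different ranks, so only this rank's half is kept; with cfg world size 1 the
# two halves are concatenated along the batch dimension as in the ORIGIN block below.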
if self.do_classifier_free_guidance:
(
prompt_embeds,
prompt_attention_mask,
) = self._process_cfg_split_batch(
negative_prompt_embeds,
prompt_embeds,
negative_prompt_attention_mask,
prompt_attention_mask,
)
(
prompt_embeds_2,
prompt_attention_mask_2,
) = self._process_cfg_split_batch(
negative_prompt_embeds_2,
prompt_embeds_2,
negative_prompt_attention_mask_2,
prompt_attention_mask_2,
)
if get_classifier_free_guidance_world_size() == 1:
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
style = torch.cat([style] * 2, dim=0)
#! ORIGIN
# if self.do_classifier_free_guidance:
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
# prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
# prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
# add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
# style = torch.cat([style] * 2, dim=0)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
prompt_embeds_2 = prompt_embeds_2.to(device=device)
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
batch_size * num_images_per_prompt, 1
)
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
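# With pipeline parallelism, the first `warmup_steps` timesteps run through the synchronous
# pipeline and the remaining timesteps through the asynchronous PipeFusion pipeline, which
# streams latent patches between stages to overlap communication with computation; otherwise
# the whole schedule runs synchronously.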
with self.progress_bar(total=num_inference_steps) as progress_bar:
if (
get_pipeline_parallel_world_size() > 1
and len(timesteps) > num_pipeline_warmup_steps
):
# * warmup stage
latents = self._sync_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
prompt_embeds_2=prompt_embeds_2,
prompt_attention_mask_2=prompt_attention_mask_2,
add_time_ids=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
device=device,
guidance_scale=guidance_scale,
guidance_rescale=guidance_rescale,
timesteps=timesteps[:num_pipeline_warmup_steps],
num_warmup_steps=num_warmup_steps,
extra_step_kwargs=extra_step_kwargs,
progress_bar=progress_bar,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
# * pipefusion stage
latents = self._async_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
prompt_embeds_2=prompt_embeds_2,
prompt_attention_mask_2=prompt_attention_mask_2,
add_time_ids=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
device=device,
guidance_scale=guidance_scale,
guidance_rescale=guidance_rescale,
timesteps=timesteps[num_pipeline_warmup_steps:],
num_warmup_steps=num_warmup_steps,
extra_step_kwargs=extra_step_kwargs,
progress_bar=progress_bar,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
else:
latents = self._sync_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
prompt_embeds_2=prompt_embeds_2,
prompt_attention_mask_2=prompt_attention_mask_2,
add_time_ids=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
device=device,
guidance_scale=guidance_scale,
guidance_rescale=guidance_rescale,
timesteps=timesteps,
num_warmup_steps=num_warmup_steps,
extra_step_kwargs=extra_step_kwargs,
progress_bar=progress_bar,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
sync_only=True,
)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
# 9. Decode latents (only rank 0)
#! ---------------------------------------- ADD BELOW ----------------------------------------
def vae_decode(latents):
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
return image
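# With parallel VAE decoding the latents are gathered and broadcast first so decoding can run
# across ranks; otherwise only the last data-parallel group decodes.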
if not output_type == "latent":
if get_runtime_state().runtime_config.use_parallel_vae:
latents = self.gather_broadcast_latents(latents)
image = vae_decode(latents)
else:
if is_dp_last_group():
image = vae_decode(latents)
if self.is_dp_last_group():
#! ---------------------------------------- ADD ABOVE ----------------------------------------
if not output_type == "latent":
image, has_nsfw_concept = self.run_safety_checker(
image, device, prompt_embeds.dtype
)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(
image, output_type=output_type, do_denormalize=do_denormalize
)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(
images=image, nsfw_content_detected=has_nsfw_concept
)
#! ---------------------------------------- ADD BELOW ----------------------------------------
else:
return None
#! ---------------------------------------- ADD ABOVE ----------------------------------------
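# Besides the latent split done by the base class, slice the 2D rotary position embeddings
# (cos, sin) to the patch token ranges tracked by the runtime state so that positions stay
# aligned with the split latents.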
def _init_sync_pipeline(self, latents: torch.Tensor, image_rotary_emb):
latents = super()._init_sync_pipeline(latents)
image_rotary_emb = (
torch.cat(
[
image_rotary_emb[0][start_token_idx:end_token_idx, ...]
for start_token_idx, end_token_idx in get_runtime_state().pp_patches_token_start_end_idx_global
],
dim=0,
),
torch.cat(
[
image_rotary_emb[1][start_token_idx:end_token_idx, ...]
for start_token_idx, end_token_idx in get_runtime_state().pp_patches_token_start_end_idx_global
],
dim=0,
),
)
return latents, image_rotary_emb
def _init_async_pipeline(
self,
num_timesteps: int,
latents: torch.Tensor,
num_pipeline_warmup_steps: int,
):
patch_latents = super()._init_async_pipeline(
num_timesteps,
latents,
num_pipeline_warmup_steps,
)
if get_pipeline_parallel_rank() >= get_pipeline_parallel_world_size() // 2:
for _ in range(num_timesteps):
for patch_idx in range(get_runtime_state().num_pipeline_patch):
get_pp_group().add_pipeline_recv_skip_task(patch_idx)
return patch_latents
# synchronously compute the whole feature map in each pp stage
def _sync_pipeline(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
prompt_attention_mask: torch.Tensor,
prompt_embeds_2: torch.Tensor,
prompt_attention_mask_2: torch.Tensor,
add_time_ids: torch.Tensor,
style: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
device: torch.device,
guidance_scale: float,
guidance_rescale: float,
timesteps: List[int],
num_warmup_steps: int,
extra_step_kwargs: List,
progress_bar,
callback_on_step_end: Optional[
Union[
Callable[[int, int, Dict], None],
PipelineCallback,
MultiPipelineCallbacks,
]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
sync_only: bool = False,
):
latents, image_rotary_emb = self._init_sync_pipeline(latents, image_rotary_emb)
skips = None
for i, t in enumerate(timesteps):
if is_pipeline_last_stage():
last_timestep_latents = latents
# when there is only one pp stage, no need to recv
if get_pipeline_parallel_world_size() == 1:
pass
# all ranks should recv the latent from the previous rank except
# the first rank in the first pipeline forward which should use
# the input latent
elif is_pipeline_first_stage() and i == 0:
pass
else:
latents = get_pp_group().pipeline_recv()
if (
get_pipeline_parallel_rank()
>= get_pipeline_parallel_world_size() // 2
):
skips = get_pp_group().pipeline_recv_skip()
latents = self._backbone_forward(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
prompt_embeds_2=prompt_embeds_2,
prompt_attention_mask_2=prompt_attention_mask_2,
add_time_ids=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
t=t,
device=device,
guidance_scale=guidance_scale,
guidance_rescale=guidance_rescale,
skips=skips,
)
if is_pipeline_last_stage():
latents = self.scheduler.step(
latents,
t,
last_timestep_latents,
**extra_step_kwargs,
return_dict=False,
)[0]
elif (
get_pipeline_parallel_rank() >= get_pipeline_parallel_world_size() // 2
):
pass
else:
latents, skips = latents
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop(
"negative_prompt_embeds", negative_prompt_embeds
)
prompt_embeds_2 = callback_outputs.pop(
"prompt_embeds_2", prompt_embeds_2
)
negative_prompt_embeds_2 = callback_outputs.pop(
"negative_prompt_embeds_2", negative_prompt_embeds_2
)
if sync_only and is_pipeline_last_stage() and i == len(timesteps) - 1:
pass
elif get_pipeline_parallel_world_size() > 1:
get_pp_group().pipeline_send(latents)
if (
get_pipeline_parallel_rank()
< get_pipeline_parallel_world_size() // 2
):
get_pp_group().pipeline_send_skip(skips)
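# After the final timestep each sequence-parallel rank holds only its slice of every pipeline
# patch; gather all slices and reorder them patch by patch to reconstruct the full latent.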
if (
sync_only
and get_sequence_parallel_world_size() > 1
and is_pipeline_last_stage()
):
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
:,
:,
get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
latents = torch.cat(latents_list, dim=-2)
return latents
# * implementation of PipeFusion
def _async_pipeline(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
prompt_attention_mask: torch.Tensor,
prompt_embeds_2: torch.Tensor,
prompt_attention_mask_2: torch.Tensor,
add_time_ids: torch.Tensor,
style: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
device: torch.device,
guidance_scale: float,
guidance_rescale: float,
timesteps: List[int],
num_warmup_steps: int,
extra_step_kwargs: List,
progress_bar,
callback_on_step_end: Optional[
Union[
Callable[[int, int, Dict], None],
PipelineCallback,
MultiPipelineCallbacks,
]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
if len(timesteps) == 0:
return latents
num_pipeline_patch = get_runtime_state().num_pipeline_patch
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
patch_latents = self._init_async_pipeline(
num_timesteps=len(timesteps),
latents=latents,
num_pipeline_warmup_steps=num_pipeline_warmup_steps,
)
full_image_rotary_emb = image_rotary_emb
last_patch_latents = (
[None for _ in range(num_pipeline_patch)]
if (is_pipeline_last_stage())
else None
)
first_async_recv = True
skips = None
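# PipeFusion loop: for each timestep and each latent patch, receive the patch (and, on
# second-half stages, the skip activations) from the previous stage, run this stage's blocks
# on it, then asynchronously send the result onward while the next receive is already posted.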
for i, t in enumerate(timesteps):
for patch_idx in range(num_pipeline_patch):
start_token_idx, end_token_idx = (
get_runtime_state().pp_patches_token_start_end_idx_global[patch_idx]
)
image_rotary_emb = (
full_image_rotary_emb[0][start_token_idx:end_token_idx, :],
full_image_rotary_emb[1][start_token_idx:end_token_idx, :],
)
if is_pipeline_last_stage():
last_patch_latents[patch_idx] = patch_latents[patch_idx]
if is_pipeline_first_stage() and i == 0:
pass
else:
if first_async_recv:
get_pp_group().recv_next()
if (
get_pipeline_parallel_rank()
>= get_pipeline_parallel_world_size() // 2
):
get_pp_group().recv_skip_next()
first_async_recv = False
patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data(
idx=patch_idx
)
if (
get_pipeline_parallel_rank()
>= get_pipeline_parallel_world_size() // 2
):
skips = get_pp_group().get_pipeline_recv_skip_data(
idx=patch_idx
)
patch_latents[patch_idx] = self._backbone_forward(
latents=patch_latents[patch_idx],
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
prompt_embeds_2=prompt_embeds_2,
prompt_attention_mask_2=prompt_attention_mask_2,
add_time_ids=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
t=t,
device=device,
guidance_scale=guidance_scale,
guidance_rescale=guidance_rescale,
skips=skips,
)
if is_pipeline_last_stage():
patch_latents[patch_idx] = self.scheduler.step(
patch_latents[patch_idx],
t,
last_patch_latents[patch_idx],
**extra_step_kwargs,
return_dict=False,
)[0]
if i != len(timesteps) - 1:
get_pp_group().pipeline_isend(
patch_latents[patch_idx], segment_idx=patch_idx
)
elif (
get_pipeline_parallel_rank()
>= get_pipeline_parallel_world_size() // 2
):
get_pp_group().pipeline_isend(
patch_latents[patch_idx], segment_idx=patch_idx
)
else:
patch_latents[patch_idx], skips = patch_latents[patch_idx]
get_pp_group().pipeline_isend(
patch_latents[patch_idx], segment_idx=patch_idx
)
get_pp_group().pipeline_isend_skip(skips)
if is_pipeline_first_stage() and i == 0:
pass
else:
if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
pass
else:
get_pp_group().recv_next()
if (
get_pipeline_parallel_rank()
>= get_pipeline_parallel_world_size() // 2
):
get_pp_group().recv_skip_next()
get_runtime_state().next_patch()
if i == len(timesteps) - 1 or (
(i + num_pipeline_warmup_steps + 1) > num_warmup_steps
and (i + num_pipeline_warmup_steps + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop(
"negative_prompt_embeds", negative_prompt_embeds
)
prompt_embeds_2 = callback_outputs.pop(
"prompt_embeds_2", prompt_embeds_2
)
negative_prompt_embeds_2 = callback_outputs.pop(
"negative_prompt_embeds_2", negative_prompt_embeds_2
)
latents = None
if is_pipeline_last_stage():
latents = torch.cat(patch_latents, dim=2)
if get_sequence_parallel_world_size() > 1:
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(
latents, separate_tensors=True
)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
...,
get_runtime_state()
.pp_patches_start_idx_local[
pp_patch_idx
] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
latents = torch.cat(latents_list, dim=-2)
return latents
def _backbone_forward(
self,
latents: torch.FloatTensor,
prompt_embeds: torch.FloatTensor,
prompt_attention_mask: torch.FloatTensor,
prompt_embeds_2: torch.FloatTensor,
prompt_attention_mask_2: torch.FloatTensor,
add_time_ids: torch.Tensor,
style: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
t: Union[float, torch.Tensor],
device: torch.device,
guidance_scale: float,
guidance_rescale: float,
skips: torch.FloatTensor,
):
if is_pipeline_first_stage():
if self.do_classifier_free_guidance:
latents = torch.cat(
[latents] * (2 // get_classifier_free_guidance_world_size())
)
latents = self.scheduler.scale_model_input(latents, t)
# expand scalar t to 1-D tensor to match the 1st dim of latents
t_expand = torch.tensor([t] * latents.shape[0], device=device).to(
dtype=latents.dtype
)
# predict the noise residual
noise_pred = self.transformer(
latents,
t_expand,
encoder_hidden_states=prompt_embeds,
text_embedding_mask=prompt_attention_mask,
encoder_hidden_states_t5=prompt_embeds_2,
text_embedding_mask_t5=prompt_attention_mask_2,
image_meta_size=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
skips=skips,
return_dict=False,
)[0]
if is_pipeline_last_stage():
noise_pred, _ = noise_pred.chunk(2, dim=1)
# perform guidance
if self.do_classifier_free_guidance:
if get_classifier_free_guidance_world_size() == 1:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
elif get_classifier_free_guidance_world_size() == 2:
noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(
noise_pred, separate_tensors=True
)
latents = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
latents = rescale_noise_cfg(
latents, noise_pred_text, guidance_rescale=guidance_rescale
)
else:
latents = noise_pred
return latents
import os
from typing import List, Tuple, Callable, Optional, Union, Dict
import torch
import torch.distributed
from diffusers import LattePipeline
from diffusers.pipelines.latte.pipeline_latte import (
LattePipelineOutput,
retrieve_timesteps,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.utils import deprecate
from xfuser.config import EngineConfig
from xfuser.core.distributed import (
get_data_parallel_world_size,
get_data_parallel_rank,
get_classifier_free_guidance_world_size,
get_classifier_free_guidance_rank,
get_sequence_parallel_world_size,
get_pipeline_parallel_world_size,
get_pipeline_parallel_rank,
get_runtime_state,
initialize_runtime_state,
get_pp_group,
get_world_group,
get_cfg_group,
get_sp_group,
is_pipeline_first_stage,
is_dp_last_group,
)
from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
@xFuserPipelineWrapperRegister.register(LattePipeline)
class xFuserLattePipeline(xFuserPipelineBaseWrapper):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
engine_config: EngineConfig,
**kwargs,
):
pipeline = LattePipeline.from_pretrained(
pretrained_model_name_or_path, **kwargs
)
return cls(pipeline, engine_config)
@torch.no_grad()
@xFuserPipelineBaseWrapper.enable_data_parallel
@xFuserPipelineBaseWrapper.check_to_use_naive_forward
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 7.5,
num_images_per_prompt: int = 1,
video_length: int = 16,
height: int = 512,
width: int = 512,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[
Callable[[int, int, Dict], None],
PipelineCallback,
MultiPipelineCallbacks,
]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
clean_caption: bool = True,
mask_feature: bool = True,
enable_temporal_attentions: bool = True,
decode_chunk_size: Optional[int] = None,
num_pipeline_warmup_steps: Optional[int] = 3,
**kwargs,
) -> Union[LattePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the video generation. If not defined, one has to pass
`prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
usually at the expense of lower video quality.
video_length (`int`, *optional*, defaults to 16):
The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second amounts to 2 seconds of video.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated video.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated video.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For Latte this negative prompt should be "". If not provided,
negative_prompt_embeds will be generated from `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A callback function or a list of callback functions to be called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
A list of tensor inputs that should be passed to the callback function. If not defined, all tensor
inputs will be passed.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
enable_temporal_attentions (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions
decode_chunk_size (`int`, *optional*):
The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
For lower memory usage, reduce `decode_chunk_size`.
Examples:
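A minimal, illustrative usage sketch (not taken from this repository). It assumes the
distributed environment and an `EngineConfig` instance named `engine_config` have already
been set up elsewhere; the model id below is a placeholder.
```py
>>> import torch
>>> # imports, distributed setup and `engine_config` construction omitted (assumed)
>>> pipe = xFuserLattePipeline.from_pretrained(
...     "maxin-cn/Latte-1",  # placeholder model id
...     engine_config=engine_config,
...     torch_dtype=torch.float16,
... )
>>> out = pipe(prompt="a dog running on the beach", video_length=16, num_inference_steps=50)
>>> if out is not None:  # ranks outside the last data-parallel group may return None
...     frames = out.frames[0]
```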
Returns:
[`~pipelines.latte.pipeline_latte.LattePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default
num_frames = video_length
decode_chunk_size = (
decode_chunk_size if decode_chunk_size is not None else num_frames
)
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._interrupt = False
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# * set runtime state input parameters
get_runtime_state().set_video_input_parameters(
height=height,
width=width,
num_frames=num_frames,
batch_size=batch_size,
num_inference_steps=num_inference_steps,
)
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
clean_caption=clean_caption,
mask_feature=mask_feature,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps
)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = max(
len(timesteps) - num_inference_steps * self.scheduler.order, 0
)
with self.progress_bar(total=num_inference_steps) as progress_bar:
latents = self._init_video_sync_pipeline(latents)
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor(
[current_timestep],
dtype=dtype,
device=latent_model_input.device,
)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(
latent_model_input.device
)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=current_timestep,
enable_temporal_attentions=enable_temporal_attentions,
return_dict=False,
)[0]
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# use learned sigma?
if not (
hasattr(self.scheduler.config, "variance_type")
and self.scheduler.config.variance_type
in ["learned", "learned_range"]
):
noise_pred = noise_pred.chunk(2, dim=1)[0]
# compute previous video: x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop(
"negative_prompt_embeds", negative_prompt_embeds
)
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
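# Gather the sequence-parallel slices of the video latents and reorder them patch by patch
# to rebuild the full latent before decoding.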
if get_sequence_parallel_world_size() > 1:
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
:,
:,
:,
get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
latents = torch.cat(latents_list, dim=-2)
if is_dp_last_group():
if not (output_type == "latents" or output_type == "latent"):
video = self.decode_latents(latents, num_frames, decode_chunk_size=14)
video = self.video_processor.postprocess_video(
video=video, output_type=output_type
)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LattePipelineOutput(frames=video)
@property
def interrupt(self):
return self._interrupt
import os
from typing import Dict, List, Tuple, Callable, Optional, Union
import torch
import torch.distributed
from diffusers import PixArtAlphaPipeline
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_256_BIN,
ASPECT_RATIO_512_BIN,
ASPECT_RATIO_1024_BIN,
)
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import retrieve_timesteps
from diffusers.utils import deprecate
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from xfuser.config import EngineConfig
from xfuser.core.distributed import (
is_dp_last_group,
get_classifier_free_guidance_world_size,
get_pipeline_parallel_world_size,
get_runtime_state,
get_cfg_group,
get_pp_group,
get_sequence_parallel_world_size,
get_sp_group,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_world_group
)
from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
@xFuserPipelineWrapperRegister.register(PixArtAlphaPipeline)
class xFuserPixArtAlphaPipeline(xFuserPipelineBaseWrapper):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
engine_config: EngineConfig,
**kwargs,
):
pipeline = PixArtAlphaPipeline.from_pretrained(
pretrained_model_name_or_path, **kwargs
)
return cls(pipeline, engine_config)
@torch.no_grad()
@xFuserPipelineBaseWrapper.enable_fast_attn
@xFuserPipelineBaseWrapper.enable_data_parallel
@xFuserPipelineBaseWrapper.check_to_use_naive_forward
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
use_resolution_binning: bool = True,
max_sequence_length: int = 120,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass
`prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_inference_steps (`int`, *optional*, defaults to 20):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
use_resolution_binning (`bool` defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
the requested resolution. Useful for generating non-square images.
max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.
Examples:
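A minimal, illustrative usage sketch (not taken from this repository). It assumes the
distributed environment and an `EngineConfig` instance named `engine_config` have already
been set up elsewhere; the model id below is a placeholder.
```py
>>> import torch
>>> # imports, distributed setup and `engine_config` construction omitted (assumed)
>>> pipe = xFuserPixArtAlphaPipeline.from_pretrained(
...     "PixArt-alpha/PixArt-XL-2-1024-MS",  # placeholder model id
...     engine_config=engine_config,
...     torch_dtype=torch.float16,
... )
>>> out = pipe(prompt="a small cactus with a happy face", num_inference_steps=20)
>>> if out is not None:  # only the last data-parallel group returns images
...     out.images[0].save("pixart.png")
```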
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images
"""
if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and does not affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
if use_resolution_binning:
if self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_512_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_256_BIN
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
height, width = self.image_processor.classify_height_width_bin(
height, width, ratios=aspect_ratio_bin
)
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
#! ---------------------------------------- ADDED BELOW ----------------------------------------
# * set runtime state input parameters
get_runtime_state().set_input_parameters(
height=height,
width=width,
batch_size=batch_size,
num_inference_steps=num_inference_steps,
)
#! ---------------------------------------- ADDED ABOVE ----------------------------------------
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
)
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
# * dealing with cfg degree
if do_classifier_free_guidance:
(
prompt_embeds,
prompt_attention_mask,
) = self._process_cfg_split_batch(
negative_prompt_embeds,
prompt_embeds,
negative_prompt_attention_mask,
prompt_attention_mask,
)
#! ORIGIN
# if do_classifier_free_guidance:
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
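# Resolution / aspect-ratio micro-conditions are only populated for the 1024px model
# (sample_size 128); as with the prompt embeddings above, they are duplicated along the
# batch dimension only when CFG is not split across ranks.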
if self.transformer.config.sample_size == 128:
resolution = torch.tensor([height, width]).repeat(
batch_size * num_images_per_prompt, 1
)
aspect_ratio = torch.tensor([float(height / width)]).repeat(
batch_size * num_images_per_prompt, 1
)
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
if (
do_classifier_free_guidance
and get_classifier_free_guidance_world_size() == 1
):
resolution = torch.cat([resolution, resolution], dim=0)
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
#! ORIGIN
# if do_classifier_free_guidance:
# resolution = torch.cat([resolution, resolution], dim=0)
# aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# 7. Denoising loop
num_warmup_steps = max(
len(timesteps) - num_inference_steps * self.scheduler.order, 0
)
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
with self.progress_bar(total=num_inference_steps) as progress_bar:
if (
get_pipeline_parallel_world_size() > 1
and len(timesteps) > num_pipeline_warmup_steps
):
# * warmup stage
latents = self._sync_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
guidance_scale=guidance_scale,
timesteps=timesteps[:num_pipeline_warmup_steps],
num_warmup_steps=num_warmup_steps,
extra_step_kwargs=extra_step_kwargs,
added_cond_kwargs=added_cond_kwargs,
progress_bar=progress_bar,
callback=callback,
callback_steps=callback_steps,
)
# * pipefusion stage
latents = self._async_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
guidance_scale=guidance_scale,
timesteps=timesteps[num_pipeline_warmup_steps:],
num_warmup_steps=num_warmup_steps,
extra_step_kwargs=extra_step_kwargs,
added_cond_kwargs=added_cond_kwargs,
progress_bar=progress_bar,
callback=callback,
callback_steps=callback_steps,
)
else:
latents = self._sync_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
guidance_scale=guidance_scale,
timesteps=timesteps,
num_warmup_steps=num_warmup_steps,
extra_step_kwargs=extra_step_kwargs,
added_cond_kwargs=added_cond_kwargs,
progress_bar=progress_bar,
callback=callback,
callback_steps=callback_steps,
sync_only=True,
)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
# 8. Decode latents (only rank 0)
#! ---------------------------------------- ADD BELOW ----------------------------------------
def vae_decode(latents):
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
return image
if not output_type == "latent":
if get_runtime_state().runtime_config.use_parallel_vae:
latents = self.gather_broadcast_latents(latents)
image = vae_decode(latents)
else:
if is_dp_last_group():
image = vae_decode(latents)
if self.is_dp_last_group():
#! ---------------------------------------- ADD ABOVE ----------------------------------------
if not output_type == "latent":
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(
image, orig_width, orig_height
)
else:
image = latents
if not output_type == "latent":
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
#! ---------------------------------------- ADD BELOW ----------------------------------------
else:
return None
#! ---------------------------------------- ADD ABOVE ----------------------------------------
def _scheduler_step(
self,
noise_pred: torch.Tensor,
latents: torch.Tensor,
t: Union[float, torch.Tensor],
extra_step_kwargs: Dict,
):
# compute previous image: x_t -> x_t-1
return self.scheduler.step(
noise_pred,
t,
latents,
**extra_step_kwargs,
return_dict=False,
)[0]
# synchronously compute the whole feature map in each pipeline-parallel (pp) stage
def _sync_pipeline(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
prompt_attention_mask: torch.Tensor,
guidance_scale: float,
timesteps: List[int],
num_warmup_steps: int,
extra_step_kwargs: List,
added_cond_kwargs: Dict,
progress_bar,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
sync_only: bool = False,
):
latents = self._init_sync_pipeline(latents)
for i, t in enumerate(timesteps):
if is_pipeline_last_stage():
last_timestep_latents = latents
# when there is only one pp stage, no need to recv
if get_pipeline_parallel_world_size() == 1:
pass
# all ranks should recv the latent from the previous rank except
# the first rank in the first pipeline forward which should use
# the input latent
elif is_pipeline_first_stage() and i == 0:
pass
else:
latents = get_pp_group().pipeline_recv()
latents = self._backbone_forward(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
added_cond_kwargs=added_cond_kwargs,
t=t,
guidance_scale=guidance_scale,
)
if is_pipeline_last_stage():
latents = self._scheduler_step(
latents, last_timestep_latents, t, extra_step_kwargs
)
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if sync_only and is_pipeline_last_stage() and i == len(timesteps) - 1:
pass
elif get_pipeline_parallel_world_size() > 1:
get_pp_group().pipeline_send(latents)
if (
sync_only
and get_sequence_parallel_world_size() > 1
and is_pipeline_last_stage()
):
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
:,
:,
get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
latents = torch.cat(latents_list, dim=-2)
return latents
# * implementation of PipeFusion
def _async_pipeline(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
prompt_attention_mask: torch.Tensor,
guidance_scale: float,
timesteps: List[int],
num_warmup_steps: int,
extra_step_kwargs: List,
added_cond_kwargs: Dict,
progress_bar,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
if len(timesteps) == 0:
return latents
num_pipeline_patch = get_runtime_state().num_pipeline_patch
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
patch_latents = self._init_async_pipeline(
num_timesteps=len(timesteps),
latents=latents,
num_pipeline_warmup_steps=num_pipeline_warmup_steps,
)
last_patch_latents = (
[None for _ in range(num_pipeline_patch)]
if (is_pipeline_last_stage())
else None
)
first_async_recv = True
for i, t in enumerate(timesteps):
for patch_idx in range(num_pipeline_patch):
if is_pipeline_last_stage():
last_patch_latents[patch_idx] = patch_latents[patch_idx]
if is_pipeline_first_stage() and i == 0:
pass
else:
if first_async_recv:
get_pp_group().recv_next()
first_async_recv = False
patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data(
idx=patch_idx
)
patch_latents[patch_idx] = self._backbone_forward(
latents=patch_latents[patch_idx],
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
added_cond_kwargs=added_cond_kwargs,
t=t,
guidance_scale=guidance_scale,
)
if is_pipeline_last_stage():
patch_latents[patch_idx] = self._scheduler_step(
patch_latents[patch_idx],
last_patch_latents[patch_idx],
t,
extra_step_kwargs,
)
if i != len(timesteps) - 1:
get_pp_group().pipeline_isend(
patch_latents[patch_idx], segment_idx=patch_idx
)
else:
get_pp_group().pipeline_isend(
patch_latents[patch_idx], segment_idx=patch_idx
)
if is_pipeline_first_stage() and i == 0:
pass
else:
if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
pass
else:
get_pp_group().recv_next()
get_runtime_state().next_patch()
if i == len(timesteps) - 1 or (
(i + num_pipeline_warmup_steps + 1) > num_warmup_steps
and (i + num_pipeline_warmup_steps + 1) % self.scheduler.order == 0
):
progress_bar.update()
assert callback is None, "callback not supported in async pipeline"
if (
callback is not None
and (i + num_pipeline_warmup_steps) % callback_steps == 0
):
step_idx = (i + num_pipeline_warmup_steps) // getattr(
self.scheduler, "order", 1
)
callback(step_idx, t, patch_latents[patch_idx])
latents = None
if is_pipeline_last_stage():
latents = torch.cat(patch_latents, dim=2)
if get_sequence_parallel_world_size() > 1:
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(
latents, separate_tensors=True
)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
...,
get_runtime_state()
.pp_patches_start_idx_local[
pp_patch_idx
] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
latents = torch.cat(latents_list, dim=-2)
return latents
def _backbone_forward(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
prompt_attention_mask: torch.Tensor,
added_cond_kwargs: Dict,
t: Union[float, torch.Tensor],
guidance_scale: float,
):
if is_pipeline_first_stage():
latents = torch.cat(
[latents] * (2 // get_classifier_free_guidance_world_size())
)
latents = self.scheduler.scale_model_input(latents, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latents.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor(
[current_timestep], dtype=dtype, device=latents.device
)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latents.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latents.shape[0])
noise_pred = self.transformer(
latents,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# classifier free guidance
if is_pipeline_last_stage():
if get_classifier_free_guidance_world_size() == 1:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
elif get_classifier_free_guidance_world_size() == 2:
noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(
noise_pred, separate_tensors=True
)
latents = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if (
self.transformer.config.out_channels // 2
== self.transformer.config.in_channels
):
latents = latents.chunk(2, dim=1)[0]
else:
latents = noise_pred
return latents
import os
from typing import Dict, List, Tuple, Callable, Optional, Union
import torch
import torch.distributed
from diffusers import PixArtSigmaPipeline
from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import (
ASPECT_RATIO_256_BIN,
ASPECT_RATIO_512_BIN,
ASPECT_RATIO_1024_BIN,
ASPECT_RATIO_2048_BIN,
retrieve_timesteps,
)
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from xfuser.config import EngineConfig
from xfuser.core.distributed import (
is_dp_last_group,
get_classifier_free_guidance_world_size,
get_pipeline_parallel_world_size,
get_runtime_state,
get_cfg_group,
get_pp_group,
get_sequence_parallel_world_size,
get_sp_group,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_world_group
)
from .base_pipeline import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
@xFuserPipelineWrapperRegister.register(PixArtSigmaPipeline)
class xFuserPixArtSigmaPipeline(xFuserPipelineBaseWrapper):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
engine_config: EngineConfig,
**kwargs,
):
pipeline = PixArtSigmaPipeline.from_pretrained(
pretrained_model_name_or_path, **kwargs
)
return cls(pipeline, engine_config)
@torch.no_grad()
@xFuserPipelineBaseWrapper.enable_fast_attn
@xFuserPipelineBaseWrapper.enable_data_parallel
@xFuserPipelineBaseWrapper.check_to_use_naive_forward
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
use_resolution_binning: bool = True,
max_sequence_length: int = 300,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_inference_steps (`int`, *optional*, defaults to 20):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.transformer.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.transformer.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
use_resolution_binning (`bool` defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
the requested resolution. Useful for generating non-square images.
max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`.
Examples:
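A minimal usage sketch (assuming a distributed launch, e.g. via torchrun, and an `engine_config`
built beforehand, for instance from xfuser's CLI argument helpers; the checkpoint id below is
only an example):
>>> import torch
>>> pipe = xFuserPixArtSigmaPipeline.from_pretrained(
...     "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
...     engine_config=engine_config,
...     torch_dtype=torch.float16,
... ).to("cuda")
>>> output = pipe(prompt="an astronaut riding a horse on the moon", num_inference_steps=20)
>>> # ranks outside the last data-parallel group receive None
>>> if output is not None:
...     image = output.images[0]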
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images
"""
# * check pp world size
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
if use_resolution_binning:
if self.transformer.config.sample_size == 256:
aspect_ratio_bin = ASPECT_RATIO_2048_BIN
elif self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_512_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_256_BIN
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
height, width = self.image_processor.classify_height_width_bin(
height, width, ratios=aspect_ratio_bin
)
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
# 2. Determine batch size
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# * set runtime state input parameters
get_runtime_state().set_input_parameters(
height=height,
width=width,
batch_size=batch_size,
num_inference_steps=num_inference_steps,
)
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
)
# * dealing with cfg degree
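# when the cfg group has a single rank, negative and positive prompt embeddings are concatenated
# into one batch; with cfg parallelism each rank keeps only the half (negative or positive) it is
# responsible for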
if do_classifier_free_guidance:
(
prompt_embeds,
prompt_attention_mask,
) = self._process_cfg_split_batch(
negative_prompt_embeds,
prompt_embeds,
negative_prompt_attention_mask,
prompt_attention_mask,
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
# 7. Denoising loop
num_warmup_steps = max(
len(timesteps) - num_inference_steps * self.scheduler.order, 0
)
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
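# Note: `num_warmup_steps` above is the scheduler-order warmup used for progress-bar/callback
# gating, while `num_pipeline_warmup_steps` is the number of full-feature-map steps run before
# the PipeFusion stage begins.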
with self.progress_bar(total=num_inference_steps) as progress_bar:
if (
get_pipeline_parallel_world_size() > 1
and len(timesteps) > num_pipeline_warmup_steps
):
# * warmup stage
latents = self._sync_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
guidance_scale=guidance_scale,
timesteps=timesteps[:num_pipeline_warmup_steps],
num_warmup_steps=num_warmup_steps,
extra_step_kwargs=extra_step_kwargs,
added_cond_kwargs=added_cond_kwargs,
progress_bar=progress_bar,
callback=callback,
callback_steps=callback_steps,
)
# * pipefusion stage
latents = self._async_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
guidance_scale=guidance_scale,
timesteps=timesteps[num_pipeline_warmup_steps:],
num_warmup_steps=num_warmup_steps,
extra_step_kwargs=extra_step_kwargs,
added_cond_kwargs=added_cond_kwargs,
progress_bar=progress_bar,
callback=callback,
callback_steps=callback_steps,
)
else:
latents = self._sync_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
guidance_scale=guidance_scale,
timesteps=timesteps,
num_warmup_steps=num_warmup_steps,
extra_step_kwargs=extra_step_kwargs,
added_cond_kwargs=added_cond_kwargs,
progress_bar=progress_bar,
callback=callback,
callback_steps=callback_steps,
sync_only=True,
)
# * 8. Decode latents (only the last rank in a dp group)
def vae_decode(latents):
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
return image
if not output_type == "latent":
if get_runtime_state().runtime_config.use_parallel_vae:
latents = self.gather_broadcast_latents(latents)
image = vae_decode(latents)
else:
if is_dp_last_group():
image = vae_decode(latents)
if self.is_dp_last_group():
if not output_type == "latent":
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(
image, orig_width, orig_height
)
else:
image = latents
if not output_type == "latent":
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
else:
return None
def _scheduler_step(
self,
noise_pred: torch.Tensor,
latents: torch.Tensor,
t: Union[float, torch.Tensor],
extra_step_kwargs: Dict,
):
# compute previous image: x_t -> x_t-1
return self.scheduler.step(
noise_pred,
t,
latents,
**extra_step_kwargs,
return_dict=False,
)[0]
# synchronously compute the whole feature map in each pipeline-parallel (pp) stage
def _sync_pipeline(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
prompt_attention_mask: torch.Tensor,
guidance_scale: float,
timesteps: List[int],
num_warmup_steps: int,
extra_step_kwargs: List,
added_cond_kwargs: Dict,
progress_bar,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
sync_only: bool = False,
):
latents = self._init_sync_pipeline(latents)
for i, t in enumerate(timesteps):
if is_pipeline_last_stage():
last_timestep_latents = latents
# when there is only one pp stage, no need to recv
if get_pipeline_parallel_world_size() == 1:
pass
# all ranks should recv the latent from the previous rank except
# the first rank in the first pipeline forward which should use
# the input latent
elif is_pipeline_first_stage() and i == 0:
pass
else:
latents = get_pp_group().pipeline_recv()
latents = self._backbone_forward(
latents=latents,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
added_cond_kwargs=added_cond_kwargs,
t=t,
guidance_scale=guidance_scale,
)
if is_pipeline_last_stage():
latents = self._scheduler_step(
latents, last_timestep_latents, t, extra_step_kwargs
)
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if sync_only and is_pipeline_last_stage() and i == len(timesteps) - 1:
pass
elif get_pipeline_parallel_world_size() > 1:
get_pp_group().pipeline_send(latents)
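# With sequence parallelism, each sp rank holds a slice of every pipeline patch; gather the
# per-rank slices and re-stitch them patch by patch along the patched (height) dimension.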
if (
sync_only
and get_sequence_parallel_world_size() > 1
and is_pipeline_last_stage()
):
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
:,
:,
get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
latents = torch.cat(latents_list, dim=-2)
return latents
# * implementation of PipeFusion
def _async_pipeline(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
prompt_attention_mask: torch.Tensor,
guidance_scale: float,
timesteps: List[int],
num_warmup_steps: int,
extra_step_kwargs: List,
added_cond_kwargs: Dict,
progress_bar,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
if len(timesteps) == 0:
return latents
num_pipeline_patch = get_runtime_state().num_pipeline_patch
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
patch_latents = self._init_async_pipeline(
num_timesteps=len(timesteps),
latents=latents,
num_pipeline_warmup_steps=num_pipeline_warmup_steps,
)
last_patch_latents = (
[None for _ in range(num_pipeline_patch)]
if (is_pipeline_last_stage())
else None
)
first_async_recv = True
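# Each step processes the pipeline patches one at a time: receive a patch from the previous
# stage, run the backbone on it, and asynchronously send the result to the next stage while the
# recv for the following patch is already in flight.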
for i, t in enumerate(timesteps):
for patch_idx in range(num_pipeline_patch):
if is_pipeline_last_stage():
last_patch_latents[patch_idx] = patch_latents[patch_idx]
if is_pipeline_first_stage() and i == 0:
pass
else:
if first_async_recv:
get_pp_group().recv_next()
first_async_recv = False
patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data(
idx=patch_idx
)
patch_latents[patch_idx] = self._backbone_forward(
latents=patch_latents[patch_idx],
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
added_cond_kwargs=added_cond_kwargs,
t=t,
guidance_scale=guidance_scale,
)
if is_pipeline_last_stage():
patch_latents[patch_idx] = self._scheduler_step(
patch_latents[patch_idx],
last_patch_latents[patch_idx],
t,
extra_step_kwargs,
)
if i != len(timesteps) - 1:
get_pp_group().pipeline_isend(
patch_latents[patch_idx], segment_idx=patch_idx
)
else:
get_pp_group().pipeline_isend(
patch_latents[patch_idx], segment_idx=patch_idx
)
if is_pipeline_first_stage() and i == 0:
pass
else:
if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
pass
else:
get_pp_group().recv_next()
get_runtime_state().next_patch()
if i == len(timesteps) - 1 or (
(i + num_pipeline_warmup_steps + 1) > num_warmup_steps
and (i + num_pipeline_warmup_steps + 1) % self.scheduler.order == 0
):
progress_bar.update()
assert callback is None, "callback not supported in async pipeline"
if (
callback is not None
and (i + num_pipeline_warmup_steps) % callback_steps == 0
):
step_idx = (i + num_pipeline_warmup_steps) // getattr(
self.scheduler, "order", 1
)
callback(step_idx, t, patch_latents[patch_idx])
latents = None
if is_pipeline_last_stage():
latents = torch.cat(patch_latents, dim=2)
if get_sequence_parallel_world_size() > 1:
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(
latents, separate_tensors=True
)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
...,
get_runtime_state()
.pp_patches_start_idx_local[
pp_patch_idx
] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
latents = torch.cat(latents_list, dim=-2)
return latents
def _backbone_forward(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
prompt_attention_mask: torch.Tensor,
added_cond_kwargs: Dict,
t: Union[float, torch.Tensor],
guidance_scale: float,
):
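# One denoising forward through the transformer. On the first pp stage the latents are duplicated
# for classifier-free guidance only when the cfg group has a single rank
# (2 // cfg_world_size == 2); with cfg parallelism each rank keeps a single copy.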
if is_pipeline_first_stage():
latents = torch.cat(
[latents] * (2 // get_classifier_free_guidance_world_size())
)
latents = self.scheduler.scale_model_input(latents, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latents.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor(
[current_timestep], dtype=dtype, device=latents.device
)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latents.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latents.shape[0])
noise_pred = self.transformer(
latents,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# classifier free guidance
if is_pipeline_last_stage():
if get_classifier_free_guidance_world_size() == 1:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
elif get_classifier_free_guidance_world_size() == 2:
noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(
noise_pred, separate_tensors=True
)
latents = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
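# when the transformer predicts learned sigma (out_channels == 2 * in_channels), keep only the
# noise-prediction channels for the scheduler step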
if (
self.transformer.config.out_channels // 2
== self.transformer.config.in_channels
):
latents = latents.chunk(2, dim=1)[0]
else:
latents = noise_pred
return latents