Unverified Commit c2916175 authored by YiYi Xu, committed by GitHub

Flux followup (#9074)

* refactor rotary embeds

* add jsmidt as co-author of this PR (for https://github.com/huggingface/diffusers/pull/9133)



---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Joseph Smidt <josephsmidt@gmail.com>
parent 9003d75f
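The heart of the refactor is dropping Flux's private complex-pair `apply_rope` helper in favor of the shared real-valued `apply_rotary_emb` path, driven by interleaved cos/sin tables. A minimal sketch of that formulation in plain torch (illustrative shapes, not the diffusers API itself):

```python
# A minimal sketch (plain torch, illustrative shapes) of the real-valued rotary
# formulation the refactor standardizes on: interleaved cos/sin tables of shape
# [seq, head_dim] rotate each consecutive (even, odd) channel pair of q and k.
import torch


def rotate_half_interleaved(x: torch.Tensor) -> torch.Tensor:
    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)


def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [batch, heads, seq, head_dim]; cos/sin: [seq, head_dim]
    return x * cos + rotate_half_interleaved(x) * sin


q = torch.randn(1, 24, 10, 128)
cos = torch.randn(10, 128)  # in practice these come from FluxPosEmbed / get_1d_rotary_pos_embed
sin = torch.randn(10, 128)
print(apply_rotary(q, cos, sin).shape)  # torch.Size([1, 24, 10, 128])
```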
@@ -1695,81 +1695,6 @@ class FusedAuraFlowAttnProcessor2_0:
return hidden_states
# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
class FluxSingleAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states
class FluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
@@ -1785,16 +1710,7 @@ class FluxAttnProcessor2_0:
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
- input_ndim = hidden_states.ndim
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
- context_input_ndim = encoder_hidden_states.ndim
- if context_input_ndim == 4:
- batch_size, channel, height, width = encoder_hidden_states.shape
- encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
- batch_size = encoder_hidden_states.shape[0]
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
@@ -1813,59 +1729,58 @@ class FluxAttnProcessor2_0:
if attn.norm_k is not None:
key = attn.norm_k(key)
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
- # YiYi to-do: update uising apply_rotary_emb
- # from ..embeddings import apply_rotary_emb
- # query = apply_rotary_emb(query, image_rotary_emb)
- # key = apply_rotary_emb(key, image_rotary_emb)
- query, key = apply_rope(query, key, image_rotary_emb)
+ from .embeddings import apply_rotary_emb
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
+ if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
- if context_input_ndim == 4:
- encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states

class XFormersAttnAddedKVProcessor:
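With `FluxSingleAttnProcessor2_0` folded into `FluxAttnProcessor2_0`, one processor now serves both block types: when `encoder_hidden_states` is given, the context projections are concatenated in front of the image tokens along the sequence dimension before attention and split back out afterwards; without it, the single-block path simply returns the attended hidden states. A stripped-down sketch of the concat/attend/split pattern (plain torch, hypothetical shapes):

```python
import torch
import torch.nn.functional as F

batch, heads, head_dim = 2, 24, 128
txt_len, img_len = 77, 256

q_img = torch.randn(batch, heads, img_len, head_dim)
k_img = torch.randn(batch, heads, img_len, head_dim)
v_img = torch.randn(batch, heads, img_len, head_dim)
q_txt = torch.randn(batch, heads, txt_len, head_dim)
k_txt = torch.randn(batch, heads, txt_len, head_dim)
v_txt = torch.randn(batch, heads, txt_len, head_dim)

# Context (text) tokens go first, matching the concat order in the processor.
q = torch.cat([q_txt, q_img], dim=2)
k = torch.cat([k_txt, k_img], dim=2)
v = torch.cat([v_txt, v_img], dim=2)

out = F.scaled_dot_product_attention(q, k, v)  # [batch, heads, txt_len + img_len, head_dim]
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)

# Split back into text (encoder) and image streams, as the processor does.
encoder_out, hidden_out = out[:, :txt_len], out[:, txt_len:]
print(encoder_out.shape, hidden_out.shape)
```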
@@ -4105,6 +4020,17 @@ class LoRAAttnAddedKVProcessor:
pass
class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
super().__init__()
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
......
@@ -24,9 +24,9 @@ from ..models.attention_processor import AttentionProcessor
from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from .controlnet import BaseOutput, zero_module
- from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+ from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from .modeling_outputs import Transformer2DModelOutput
- from .transformers.transformer_flux import EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock
+ from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -59,7 +59,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
- self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
@@ -272,8 +272,20 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
- txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
- ids = torch.cat((txt_ids, img_ids), dim=1)
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` as a 3d torch.Tensor is deprecated. "
+ "Please remove the batch dimension and pass it as a 2d torch.Tensor."
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` as a 3d torch.Tensor is deprecated. "
+ "Please remove the batch dimension and pass it as a 2d torch.Tensor."
+ )
+ img_ids = img_ids[0]
+ ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
block_samples = ()
......
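`txt_ids` and `img_ids` are now expected as 2d tensors shared across the whole batch; 3d inputs are still accepted but warned about and reduced to their first entry. A small sketch of both conventions (hypothetical sequence lengths):

```python
import torch

seq_len, img_tokens = 77, 1024

# New convention: 2d position ids, shared across the whole batch.
txt_ids = torch.zeros(seq_len, 3)
img_ids = torch.zeros(img_tokens, 3)
ids = torch.cat((txt_ids, img_ids), dim=0)  # [seq_len + img_tokens, 3]

# Deprecated 3d inputs keep working: the batch dimension is warned about and stripped.
txt_ids_3d = txt_ids.unsqueeze(0)  # [1, seq_len, 3]
if txt_ids_3d.ndim == 3:
    txt_ids = txt_ids_3d[0]  # back to [seq_len, 3]
```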
@@ -446,6 +446,7 @@ def get_1d_rotary_pos_embed(
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
+ freqs_dtype=torch.float32,  # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -468,6 +469,8 @@ def get_1d_rotary_pos_embed(
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concatenated with themselves.
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+ the dtype of the frequency tensor.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
@@ -476,19 +479,19 @@ def get_1d_rotary_pos_embed(
if isinstance(pos, int):
pos = np.arange(pos)
theta = theta * ntk_factor
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
- freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+ freqs = torch.outer(t, freqs)  # type: ignore  # [S, D/2]
if use_real and repeat_interleave_real:
- freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
- freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
return freqs_cos, freqs_sin
elif use_real:
- freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
- freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
return freqs_cos, freqs_sin
else:
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float()  # complex64  # [S, D/2]
return freqs_cis
@@ -540,6 +543,31 @@ def apply_rotary_emb(
return x_out.type_as(x)
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.squeeze().float().cpu().numpy()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
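A minimal usage sketch of the new `FluxPosEmbed`, assuming a diffusers build that already includes it; the ids layout (text ids first, image ids after, no batch dimension) and the `axes_dim` values mirror the Flux defaults elsewhere in this diff:

```python
# Usage sketch; sequence lengths are illustrative.
import torch
from diffusers.models.embeddings import FluxPosEmbed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])  # sums to head_dim = 128

txt_ids = torch.zeros(77, 3)       # text ids first, no batch dimension
img_ids = torch.zeros(64 * 64, 3)  # then image ids
ids = torch.cat((txt_ids, img_ids), dim=0)

cos, sin = pos_embed(ids)
print(cos.shape, sin.shape)  # each [77 + 64*64, 128]
```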
class TimestepEmbedding(nn.Module):
def __init__(
self,
......
@@ -13,7 +13,7 @@
# limitations under the License.
- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
@@ -23,52 +23,18 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
- from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
+ from ...models.attention_processor import Attention, FluxAttnProcessor2_0
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
- from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+ from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# YiYi to-do: refactor rope related functions/classes
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
# YiYi to-do: refactor rope related functions/classes
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
r"""
@@ -93,7 +59,7 @@ class FluxSingleTransformerBlock(nn.Module):
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
- processor = FluxSingleAttnProcessor2_0()
+ processor = FluxAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
@@ -265,13 +231,14 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
- axes_dims_rope: List[int] = [16, 56, 56],
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
- self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
@@ -381,8 +348,19 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
- txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
- ids = torch.cat((txt_ids, img_ids), dim=1)
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` as a 3d torch.Tensor is deprecated. "
+ "Please remove the batch dimension and pass it as a 2d torch.Tensor."
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` as a 3d torch.Tensor is deprecated. "
+ "Please remove the batch dimension and pass it as a 2d torch.Tensor."
+ )
+ img_ids = img_ids[0]
+ ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
......
@@ -331,10 +331,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
@@ -364,8 +360,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
- text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
- text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
@@ -425,9 +420,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
- latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
- latent_image_ids = latent_image_ids.reshape(
- batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
- )
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
return latent_image_ids.to(device=device, dtype=dtype)
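`_prepare_latent_image_ids` now returns batch-free ids as well: one `(0, row, column)` triple per latent position, flattened to 2d. An illustrative reconstruction (the grid construction itself is not shown in this diff and is only sketched here):

```python
import torch

# Illustrative reconstruction for a 32x32 latent grid: channel 0 stays 0,
# channels 1 and 2 hold the row / column index of each latent position.
h, w = 32, 32
latent_image_ids = torch.zeros(h, w, 3)
latent_image_ids[..., 1] = torch.arange(h)[:, None]
latent_image_ids[..., 2] = torch.arange(w)[None, :]

# After this change the ids are flattened without a batch dimension.
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
    latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
print(latent_image_ids.shape)  # torch.Size([1024, 3])
```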
@@ -724,7 +718,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
noise_pred = self.transformer(
hidden_states=latents,
- # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
......
@@ -354,10 +354,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
@@ -387,8 +383,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
- text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
- text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
@@ -449,9 +444,8 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
- latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
- latent_image_ids = latent_image_ids.reshape(
- batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
- )
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
return latent_image_ids.to(device=device, dtype=dtype)
@@ -804,7 +798,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
noise_pred = self.transformer(
hidden_states=latents,
- # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
......
@@ -976,7 +976,6 @@ class ModelTesterMixin:
self.assertTrue(actual_num_shards == expected_num_shards)
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")
- new_model = new_model.to(torch_device)
torch.manual_seed(0)
if "generator" in inputs_dict:
......
@@ -44,8 +44,8 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
- text_ids = torch.randn((batch_size, sequence_length, num_image_channels)).to(torch_device)
- image_ids = torch.randn((batch_size, height * width, num_image_channels)).to(torch_device)
+ text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
+ image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
@@ -80,3 +80,31 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_deprecated_inputs_img_txt_ids_3d(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output_1 = model(**inputs_dict).to_tuple()[0]
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
inputs_dict["txt_ids"] = text_ids_3d
inputs_dict["img_ids"] = image_ids_3d
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
)