Unverified commit 476795c5, authored by Aryan and committed by GitHub

Update Flux docstrings (#10423)

update
parent 3cb66865
@@ -85,11 +85,11 @@ class FluxSingleTransformerBlock(nn.Module):
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
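The single-stream block's forward contract is now fully typed: plain `torch.Tensor` inputs, an optional rotary-embedding tuple, and a single `torch.Tensor` return. A minimal sketch of calling it under the new annotations; the import path, the toy sizes, and the claim that `image_rotary_emb` may be left `None` are assumptions about the surrounding diffusers code, not part of this diff:

    import torch
    from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock

    # assumed constructor arguments: dim, num_attention_heads, attention_head_dim
    block = FluxSingleTransformerBlock(dim=128, num_attention_heads=2, attention_head_dim=64)
    hidden_states = torch.randn(1, 8, 128)  # (batch_size, sequence_length, dim)
    temb = torch.randn(1, 128)              # conditioning embedding fed to the AdaLayerNorm
    out = block(hidden_states, temb)        # image_rotary_emb / joint_attention_kwargs left as None
    assert out.shape == hidden_states.shape  # a single torch.Tensor, same shape as the input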
@@ -117,15 +117,22 @@ class FluxTransformerBlock(nn.Module):
     Reference: https://arxiv.org/abs/2403.03206

-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
+    Args:
+        dim (`int`):
+            The embedding dimension of the block.
+        num_attention_heads (`int`):
+            The number of attention heads to use.
+        attention_head_dim (`int`):
+            The number of dimensions to use for each attention head.
+        qk_norm (`str`, defaults to `"rms_norm"`):
+            The normalization to use for the query and key tensors.
+        eps (`float`, defaults to `1e-6`):
+            The epsilon value to use for the normalization.
     """

-    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+    def __init__(
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+    ):
         super().__init__()

         self.norm1 = AdaLayerNormZero(dim)
@@ -164,12 +171,12 @@ class FluxTransformerBlock(nn.Module):
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
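The dual-stream block's return type is now spelled out as `Tuple[torch.Tensor, torch.Tensor]`; in the diffusers source that tuple is `(encoder_hidden_states, hidden_states)`. A hedged sketch with an assumed import path and toy sizes:

    import torch
    from diffusers.models.transformers.transformer_flux import FluxTransformerBlock

    # qk_norm="rms_norm" and eps=1e-6 are the documented defaults
    block = FluxTransformerBlock(dim=128, num_attention_heads=2, attention_head_dim=64)
    img = torch.randn(1, 8, 128)  # image stream: (batch_size, image_seq_len, dim)
    txt = torch.randn(1, 4, 128)  # text stream:  (batch_size, text_seq_len, dim)
    temb = torch.randn(1, 128)
    encoder_hidden_states, hidden_states = block(img, txt, temb)  # text stream comes first in the tuple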
@@ -227,16 +234,30 @@ class FluxTransformer2DModel(
     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

-    Parameters:
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
-        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+    Args:
+        patch_size (`int`, defaults to `1`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `64`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of channels in the output. If not specified, it defaults to `in_channels`.
+        num_layers (`int`, defaults to `19`):
+            The number of layers of dual stream DiT blocks to use.
+        num_single_layers (`int`, defaults to `38`):
+            The number of layers of single stream DiT blocks to use.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of dimensions to use for each attention head.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of attention heads to use.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The number of dimensions to use for the joint attention (embedding/channel dimension of
+            `encoder_hidden_states`).
+        pooled_projection_dim (`int`, defaults to `768`):
+            The number of dimensions to use for the pooled projection.
+        guidance_embeds (`bool`, defaults to `False`):
+            Whether to use guidance embeddings for guidance-distilled variant of the model.
+        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions to use for the rotary positional embeddings.
     """

     _supports_gradient_checkpointing = True
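The corrected defaults now describe the shipped Flux.1 transformer rather than the SD3-era placeholders the old docstring carried (18 layers, 64-dim heads, 16 input channels). As a hedged illustration, a bare constructor call therefore builds the full-size model (roughly 12B parameters, so actually running this is expensive) whose registered config echoes the documented defaults:

    from diffusers import FluxTransformer2DModel

    model = FluxTransformer2DModel()       # all defaults: the full Flux.1-sized transformer
    print(model.config.num_layers)         # 19 dual-stream blocks
    print(model.config.num_single_layers)  # 38 single-stream blocks
    print(model.config.axes_dims_rope)     # (16, 56, 56) -- sums to attention_head_dim=128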
@@ -259,7 +280,7 @@ class FluxTransformer2DModel(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.inner_dim = num_attention_heads * attention_head_dim

         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
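Worked through with the documented defaults, the rewritten line is easy to sanity-check: `inner_dim = num_attention_heads * attention_head_dim = 24 * 128 = 3072`, and that 3072-channel width is what the embedders and both block stacks below are built against.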
@@ -267,20 +288,20 @@ class FluxTransformer2DModel(
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )

         self.time_text_embed = text_time_guidance_cls(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )

-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim)

         self.transformer_blocks = nn.ModuleList(
             [
                 FluxTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )
@@ -288,10 +309,10 @@ class FluxTransformer2DModel(
             [
                 FluxSingleTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )
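Dropping the `self.config.` indirection inside `__init__` is behavior-preserving because the constructor is wrapped with diffusers' `@register_to_config` (the decorator sits outside this diff, so treat that as an assumption about the surrounding file): the decorator snapshots the arguments into `self.config`, meaning the local parameter and the config attribute hold the same value during construction. A toy sketch of that mechanism; `Toy` and `depth` are hypothetical names:

    from diffusers.configuration_utils import ConfigMixin, register_to_config
    from diffusers.models.modeling_utils import ModelMixin

    class Toy(ModelMixin, ConfigMixin):
        @register_to_config
        def __init__(self, num_layers: int = 19):
            super().__init__()
            self.depth = num_layers  # same value as self.config.num_layers, minus the lookup

    assert Toy().config.num_layers == 19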
@@ -418,16 +439,16 @@ class FluxTransformer2DModel(
         controlnet_single_block_samples=None,
         return_dict: bool = True,
         controlnet_blocks_repeat: bool = False,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.

         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                 Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
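Putting the documented shapes together, a hedged end-to-end sketch of the forward call; the deliberately tiny configuration and the 2D `(seq_len, 3)` layout of the position-id tensors are assumptions for illustration, not part of this diff:

    import torch
    from diffusers import FluxTransformer2DModel

    model = FluxTransformer2DModel(
        patch_size=1, in_channels=4, num_layers=1, num_single_layers=1,
        attention_head_dim=8, num_attention_heads=2, joint_attention_dim=32,
        pooled_projection_dim=16, axes_dims_rope=(2, 2, 4),  # rope dims sum to attention_head_dim
    )
    batch, img_seq, txt_seq = 1, 16, 8
    sample = model(
        hidden_states=torch.randn(batch, img_seq, 4),           # (batch_size, image_sequence_length, in_channels)
        encoder_hidden_states=torch.randn(batch, txt_seq, 32),  # (batch_size, text_sequence_length, joint_attention_dim)
        pooled_projections=torch.randn(batch, 16),              # (batch_size, pooled_projection_dim)
        timestep=torch.tensor([1.0]),
        img_ids=torch.zeros(img_seq, 3),
        txt_ids=torch.zeros(txt_seq, 3),
        return_dict=False,
    )[0]
    print(sample.shape)  # torch.Size([1, 16, 4]) -- same layout as the input hidden_states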