renzhc / diffusers_dcu · Commits · 476795c5

Unverified commit 476795c5, authored Jan 03, 2025 by Aryan, committed by GitHub on Jan 02, 2025

Update Flux docstrings (#10423)

update

parent 3cb66865
Changes: 1 changed file with 63 additions and 42 deletions (+63 −42)

src/diffusers/models/transformers/transformer_flux.py (+63 −42)
@@ -85,11 +85,11 @@ class FluxSingleTransformerBlock(nn.Module):
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
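Reviewer note (not part of the diff): a minimal sketch of exercising the re-typed single-stream block, assuming toy sizes and the `image_rotary_emb=None` default so the snippet stays self-contained; the real model uses 24 heads of 128 dims.

import torch

from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock

# Hypothetical toy sizes, chosen so that dim == num_attention_heads * attention_head_dim.
block = FluxSingleTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)

hidden_states = torch.randn(1, 24, 64)  # (batch_size, sequence_length, dim)
temb = torch.randn(1, 64)               # conditioning embedding consumed by the AdaLayerNorm

# image_rotary_emb stays at its new Optional[...] default; the block returns a plain torch.Tensor.
out = block(hidden_states, temb)
print(out.shape)  # torch.Size([1, 24, 64])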
@@ -117,15 +117,22 @@ class FluxTransformerBlock(nn.Module):
     Reference: https://arxiv.org/abs/2403.03206
 
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
+    Args:
+        dim (`int`):
+            The embedding dimension of the block.
+        num_attention_heads (`int`):
+            The number of attention heads to use.
+        attention_head_dim (`int`):
+            The number of dimensions to use for each attention head.
+        qk_norm (`str`, defaults to `"rms_norm"`):
+            The normalization to use for the query and key tensors.
+        eps (`float`, defaults to `1e-6`):
+            The epsilon value to use for the normalization.
     """
 
-    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+    def __init__(
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+    ):
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
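Reviewer note (not part of the diff): with the constructor now fully typed, a small sketch of instantiating the block with the documented `qk_norm`/`eps` defaults spelled out; the sizes are made up for illustration, while the full model uses dim = 24 * 128 = 3072.

from diffusers.models.transformers.transformer_flux import FluxTransformerBlock

block = FluxTransformerBlock(
    dim=64,                 # assumed equal to num_attention_heads * attention_head_dim
    num_attention_heads=4,
    attention_head_dim=16,
    qk_norm="rms_norm",     # documented default
    eps=1e-6,               # documented default
)
print(sum(p.numel() for p in block.parameters()), "parameters")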
@@ -164,12 +171,12 @@ class FluxTransformerBlock(nn.Module):
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
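Reviewer note (not part of the diff): the new `Optional[Tuple[torch.Tensor, torch.Tensor]]` annotation matches the `(cos, sin)` pair produced by `FluxPosEmbed`, and the return annotation reflects that the dual-stream block hands back both streams. A hedged sketch with toy sizes and all-zero positions:

import torch

from diffusers.models.embeddings import FluxPosEmbed
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock

block = FluxTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)

hidden_states = torch.randn(1, 16, 64)         # image stream: (batch_size, image_sequence_length, dim)
encoder_hidden_states = torch.randn(1, 8, 64)  # text stream: (batch_size, text_sequence_length, dim)
temb = torch.randn(1, 64)

# (cos, sin) rotary tables over the concatenated text+image positions;
# axes_dim is assumed to sum to attention_head_dim, as in the full model.
pos_embed = FluxPosEmbed(theta=10000, axes_dim=(4, 4, 8))
image_rotary_emb = pos_embed(torch.zeros(8 + 16, 3))

encoder_hidden_states, hidden_states = block(
    hidden_states, encoder_hidden_states, temb, image_rotary_emb=image_rotary_emb
)
print(encoder_hidden_states.shape, hidden_states.shape)  # (1, 8, 64) (1, 16, 64)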
@@ -227,16 +234,30 @@ class FluxTransformer2DModel(
     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
 
-    Parameters:
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
-        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+    Args:
+        patch_size (`int`, defaults to `1`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `64`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of channels in the output. If not specified, it defaults to `in_channels`.
+        num_layers (`int`, defaults to `19`):
+            The number of layers of dual stream DiT blocks to use.
+        num_single_layers (`int`, defaults to `38`):
+            The number of layers of single stream DiT blocks to use.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of dimensions to use for each attention head.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of attention heads to use.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The number of dimensions to use for the joint attention (embedding/channel dimension of
+            `encoder_hidden_states`).
+        pooled_projection_dim (`int`, defaults to `768`):
+            The number of dimensions to use for the pooled projection.
+        guidance_embeds (`bool`, defaults to `False`):
+            Whether to use guidance embeddings for guidance-distilled variant of the model.
+        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions to use for the rotary positional embeddings.
     """
 
     _supports_gradient_checkpointing = True
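Reviewer note (not part of the diff): the defaults documented above describe the full-size Flux transformer (roughly 12B parameters). A sketch of the constructor with a deliberately tiny, made-up configuration; only the argument names and the convention that `axes_dims_rope` sums to `attention_head_dim` (16 + 56 + 56 = 128 in the default) carry over from the real model:

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=16,
    guidance_embeds=False,
    axes_dims_rope=(2, 2, 4),  # sums to attention_head_dim
)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")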
@@ -259,7 +280,7 @@ class FluxTransformer2DModel(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
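Reviewer note: `inner_dim` is now computed from the constructor arguments instead of `self.config`; with the documented defaults the value is unchanged.

num_attention_heads, attention_head_dim = 24, 128  # documented defaults
inner_dim = num_attention_heads * attention_head_dim
assert inner_dim == 3072  # hidden width of the full Flux transformer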
@@ -267,20 +288,20 @@ class FluxTransformer2DModel(
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )
 
         self.time_text_embed = text_time_guidance_cls(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
 
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
             [
                 FluxTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )
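Reviewer note (not part of the diff): reading the local arguments instead of `self.config.*` inside `__init__` is presumably behaviour-preserving, since the `@register_to_config`-decorated constructor records the same values on `self.config`. A small sketch of checking that equivalence on a tiny, made-up configuration:

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=16,
    axes_dims_rope=(2, 2, 4),
)

# The registered config still reflects the constructor arguments, so both spellings agree.
assert model.inner_dim == model.config.num_attention_heads * model.config.attention_head_dim
assert model.x_embedder.in_features == model.config.in_channels
assert model.context_embedder.in_features == model.config.joint_attention_dim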
@@ -288,10 +309,10 @@ class FluxTransformer2DModel(
             [
                 FluxSingleTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )
@@ -418,16 +439,16 @@ class FluxTransformer2DModel(
         controlnet_single_block_samples=None,
         return_dict: bool = True,
         controlnet_blocks_repeat: bool = False,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                 Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
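Reviewer note (not part of the diff): dummy inputs with the shapes the updated docstring describes, assuming the default configuration and a 1024x1024 image (the 128x128x16 VAE latent packed 2x2 gives 4096 image tokens of 64 channels); the text length of 512 is the usual T5 sequence length and is an assumption here.

import torch

batch_size = 1
hidden_states = torch.randn(batch_size, 4096, 64)           # (batch_size, image_sequence_length, in_channels)
encoder_hidden_states = torch.randn(batch_size, 512, 4096)  # (batch_size, text_sequence_length, joint_attention_dim)
pooled_projections = torch.randn(batch_size, 768)           # (batch_size, projection_dim)
timestep = torch.tensor([1000])                             # denoising step, as a LongTensor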