Unverified commit 536c297a, authored Sep 28, 2023 by Pedro Cuenca; committed by GitHub, Sep 28, 2023
Trickle down `split_head_dim` (#5208)
Parent: 693a0d08

Showing 4 changed files with 47 additions and 2 deletions (+47 -2)
src/diffusers/models/attention_flax.py          (+23 -2)
src/diffusers/models/unet_2d_blocks_flax.py     (+15 -0)
src/diffusers/models/unet_2d_condition_flax.py   (+7 -0)
src/diffusers/pipelines/pipeline_flax_utils.py   (+2 -0)
src/diffusers/models/attention_flax.py

@@ -258,6 +258,9 @@ class FlaxBasicTransformerBlock(nn.Module):
             Parameters `dtype`
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
     """

     dim: int
     n_heads: int

@@ -266,15 +269,28 @@ class FlaxBasicTransformerBlock(nn.Module):
     only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False

     def setup(self):
         # self attention (or cross_attention if only_cross_attention is True)
         self.attn1 = FlaxAttention(
-            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+            self.dim,
+            self.n_heads,
+            self.d_head,
+            self.dropout,
+            self.use_memory_efficient_attention,
+            self.split_head_dim,
+            dtype=self.dtype,
         )
         # cross attention
         self.attn2 = FlaxAttention(
-            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+            self.dim,
+            self.n_heads,
+            self.d_head,
+            self.dropout,
+            self.use_memory_efficient_attention,
+            self.split_head_dim,
+            dtype=self.dtype,
         )
         self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)

@@ -327,6 +343,9 @@ class FlaxTransformer2DModel(nn.Module):
             Parameters `dtype`
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
     """

     in_channels: int
     n_heads: int

@@ -337,6 +356,7 @@ class FlaxTransformer2DModel(nn.Module):
     only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False

     def setup(self):
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

@@ -362,6 +382,7 @@ class FlaxTransformer2DModel(nn.Module):
                 only_cross_attention=self.only_cross_attention,
                 dtype=self.dtype,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
             )
             for _ in range(self.depth)
         ]
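The docstring added above describes the effect of the flag; the following is a minimal, self-contained sketch (not the repository's exact implementation) of what "splitting the head dimension into a new axis" means for the attention computation, assuming query/key/value projections of shape (batch, seq, heads * dim_head). The function name and shapes are illustrative only.

import jax
import jax.numpy as jnp

def split_head_attention(q, k, v, heads, dim_head):
    """Sketch of attention with heads kept as a separate axis.

    q, k, v: projections of shape (batch, seq, heads * dim_head).
    Instead of folding heads into the batch axis, reshape to
    (batch, seq, heads, dim_head) and contract with einsum.
    """
    b = q.shape[0]
    q = q.reshape(b, -1, heads, dim_head)
    k = k.reshape(b, -1, heads, dim_head)
    v = v.reshape(b, -1, heads, dim_head)

    scale = 1.0 / jnp.sqrt(dim_head)
    # (batch, heads, q_len, k_len) attention scores, one slice per head
    scores = jnp.einsum("b t n h, b f n h -> b n f t", k, q) * scale
    probs = jax.nn.softmax(scores, axis=-1)

    # Weighted sum over the key axis, then merge heads back into the feature dim
    out = jnp.einsum("b n f t, b t n h -> b f n h", probs, v)
    return out.reshape(b, -1, heads * dim_head)

Keeping the head axis explicit lets the einsum/softmax steps operate on one contiguous layout instead of a (batch * heads, seq, dim_head) reshape, which is presumably where the speed-up mentioned in the docstring for Stable Diffusion 2.x and SDXL comes from.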
src/diffusers/models/unet_2d_blocks_flax.py

@@ -39,6 +39,9 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
             Whether to add downsampling layer before each final output
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """

@@ -51,6 +54,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     use_linear_projection: bool = False
     only_cross_attention: bool = False
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False
     dtype: jnp.dtype = jnp.float32
     transformer_layers_per_block: int = 1

@@ -77,6 +81,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)

@@ -179,6 +184,9 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
             Whether to add upsampling layer before each final output
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """

@@ -192,6 +200,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     use_linear_projection: bool = False
     only_cross_attention: bool = False
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False
     dtype: jnp.dtype = jnp.float32
     transformer_layers_per_block: int = 1

@@ -219,6 +228,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)

@@ -323,6 +333,9 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
             Number of attention heads of each spatial transformer block
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """

@@ -332,6 +345,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     num_attention_heads: int = 1
     use_linear_projection: bool = False
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False
     dtype: jnp.dtype = jnp.float32
     transformer_layers_per_block: int = 1

@@ -356,6 +370,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
                 depth=self.transformer_layers_per_block,
                 use_linear_projection=self.use_linear_projection,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
src/diffusers/models/unet_2d_condition_flax.py

@@ -92,6 +92,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
     """

     sample_size: int = 32

@@ -116,6 +119,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     flip_sin_to_cos: bool = True
     freq_shift: int = 0
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False
     transformer_layers_per_block: Union[int, Tuple[int]] = 1
     addition_embed_type: Optional[str] = None
     addition_time_embed_dim: Optional[int] = None

@@ -231,6 +235,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                     use_linear_projection=self.use_linear_projection,
                     only_cross_attention=only_cross_attention[i],
                     use_memory_efficient_attention=self.use_memory_efficient_attention,
+                    split_head_dim=self.split_head_dim,
                     dtype=self.dtype,
                 )
             else:

@@ -254,6 +259,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             transformer_layers_per_block=transformer_layers_per_block[-1],
             use_linear_projection=self.use_linear_projection,
             use_memory_efficient_attention=self.use_memory_efficient_attention,
+            split_head_dim=self.split_head_dim,
             dtype=self.dtype,
         )

@@ -284,6 +290,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                     use_linear_projection=self.use_linear_projection,
                     only_cross_attention=only_cross_attention[i],
                     use_memory_efficient_attention=self.use_memory_efficient_attention,
+                    split_head_dim=self.split_head_dim,
                     dtype=self.dtype,
                 )
             else:
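Because `split_head_dim` is now a module field on FlaxUNet2DConditionModel, it can be set when constructing the UNet directly. A minimal sketch; the configuration values below are illustrative, not taken from any particular checkpoint, where real checkpoints define them in config.json.

import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# Illustrative configuration; split_head_dim is forwarded by this commit
# to every cross-attention down/mid/up block and, from there, to FlaxAttention.
unet = FlaxUNet2DConditionModel(
    sample_size=64,
    cross_attention_dim=1024,
    use_memory_efficient_attention=False,
    split_head_dim=True,
    dtype=jnp.bfloat16,
)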
src/diffusers/pipelines/pipeline_flax_utils.py

@@ -323,6 +323,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
         revision = kwargs.pop("revision", None)
         from_pt = kwargs.pop("from_pt", False)
         use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
+        split_head_dim = kwargs.pop("split_head_dim", False)
         dtype = kwargs.pop("dtype", None)

         # 1. Download the checkpoints and configs

@@ -501,6 +502,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
                     loadable_folder,
                     from_pt=from_pt,
                     use_memory_efficient_attention=use_memory_efficient_attention,
+                    split_head_dim=split_head_dim,
                     dtype=dtype,
                 )
                 params[name] = loaded_params
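With the loader change above, the flag can also be passed straight to `from_pretrained`, which pops it from the kwargs and forwards it to each Flax sub-model it loads. A hedged usage sketch; the checkpoint id is a placeholder assumption and must point at a repository that ships Flax weights (otherwise also pass `from_pt=True`).

import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Placeholder repo id; any Stable Diffusion 2.x checkpoint with Flax weights
# should behave the same way.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    dtype=jnp.bfloat16,
    split_head_dim=True,  # new kwarg introduced by this change, forwarded to sub-models
)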