Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
536c297a
Unverified
Commit
536c297a
authored
Sep 28, 2023
by
Pedro Cuenca
Committed by
GitHub
Sep 28, 2023
Browse files
Trickle down `split_head_dim` (#5208)
parent
693a0d08
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
47 additions
and
2 deletions
+47
-2
src/diffusers/models/attention_flax.py
src/diffusers/models/attention_flax.py
+23
-2
src/diffusers/models/unet_2d_blocks_flax.py
src/diffusers/models/unet_2d_blocks_flax.py
+15
-0
src/diffusers/models/unet_2d_condition_flax.py
src/diffusers/models/unet_2d_condition_flax.py
+7
-0
src/diffusers/pipelines/pipeline_flax_utils.py
src/diffusers/pipelines/pipeline_flax_utils.py
+2
-0
No files found.
src/diffusers/models/attention_flax.py
View file @
536c297a
...
@@ -258,6 +258,9 @@ class FlaxBasicTransformerBlock(nn.Module):
...
@@ -258,6 +258,9 @@ class FlaxBasicTransformerBlock(nn.Module):
Parameters `dtype`
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
"""
dim
:
int
dim
:
int
n_heads
:
int
n_heads
:
int
...
@@ -266,15 +269,28 @@ class FlaxBasicTransformerBlock(nn.Module):
...
@@ -266,15 +269,28 @@ class FlaxBasicTransformerBlock(nn.Module):
only_cross_attention
:
bool
=
False
only_cross_attention
:
bool
=
False
dtype
:
jnp
.
dtype
=
jnp
.
float32
dtype
:
jnp
.
dtype
=
jnp
.
float32
use_memory_efficient_attention
:
bool
=
False
use_memory_efficient_attention
:
bool
=
False
split_head_dim
:
bool
=
False
def
setup
(
self
):
def
setup
(
self
):
# self attention (or cross_attention if only_cross_attention is True)
# self attention (or cross_attention if only_cross_attention is True)
self
.
attn1
=
FlaxAttention
(
self
.
attn1
=
FlaxAttention
(
self
.
dim
,
self
.
n_heads
,
self
.
d_head
,
self
.
dropout
,
self
.
use_memory_efficient_attention
,
dtype
=
self
.
dtype
self
.
dim
,
self
.
n_heads
,
self
.
d_head
,
self
.
dropout
,
self
.
use_memory_efficient_attention
,
self
.
split_head_dim
,
dtype
=
self
.
dtype
,
)
)
# cross attention
# cross attention
self
.
attn2
=
FlaxAttention
(
self
.
attn2
=
FlaxAttention
(
self
.
dim
,
self
.
n_heads
,
self
.
d_head
,
self
.
dropout
,
self
.
use_memory_efficient_attention
,
dtype
=
self
.
dtype
self
.
dim
,
self
.
n_heads
,
self
.
d_head
,
self
.
dropout
,
self
.
use_memory_efficient_attention
,
self
.
split_head_dim
,
dtype
=
self
.
dtype
,
)
)
self
.
ff
=
FlaxFeedForward
(
dim
=
self
.
dim
,
dropout
=
self
.
dropout
,
dtype
=
self
.
dtype
)
self
.
ff
=
FlaxFeedForward
(
dim
=
self
.
dim
,
dropout
=
self
.
dropout
,
dtype
=
self
.
dtype
)
self
.
norm1
=
nn
.
LayerNorm
(
epsilon
=
1e-5
,
dtype
=
self
.
dtype
)
self
.
norm1
=
nn
.
LayerNorm
(
epsilon
=
1e-5
,
dtype
=
self
.
dtype
)
...
@@ -327,6 +343,9 @@ class FlaxTransformer2DModel(nn.Module):
...
@@ -327,6 +343,9 @@ class FlaxTransformer2DModel(nn.Module):
Parameters `dtype`
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
"""
in_channels
:
int
in_channels
:
int
n_heads
:
int
n_heads
:
int
...
@@ -337,6 +356,7 @@ class FlaxTransformer2DModel(nn.Module):
...
@@ -337,6 +356,7 @@ class FlaxTransformer2DModel(nn.Module):
only_cross_attention
:
bool
=
False
only_cross_attention
:
bool
=
False
dtype
:
jnp
.
dtype
=
jnp
.
float32
dtype
:
jnp
.
dtype
=
jnp
.
float32
use_memory_efficient_attention
:
bool
=
False
use_memory_efficient_attention
:
bool
=
False
split_head_dim
:
bool
=
False
def
setup
(
self
):
def
setup
(
self
):
self
.
norm
=
nn
.
GroupNorm
(
num_groups
=
32
,
epsilon
=
1e-5
)
self
.
norm
=
nn
.
GroupNorm
(
num_groups
=
32
,
epsilon
=
1e-5
)
...
@@ -362,6 +382,7 @@ class FlaxTransformer2DModel(nn.Module):
...
@@ -362,6 +382,7 @@ class FlaxTransformer2DModel(nn.Module):
only_cross_attention
=
self
.
only_cross_attention
,
only_cross_attention
=
self
.
only_cross_attention
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
split_head_dim
=
self
.
split_head_dim
,
)
)
for
_
in
range
(
self
.
depth
)
for
_
in
range
(
self
.
depth
)
]
]
...
...
src/diffusers/models/unet_2d_blocks_flax.py
View file @
536c297a
...
@@ -39,6 +39,9 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
...
@@ -39,6 +39,9 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
Whether to add downsampling layer before each final output
Whether to add downsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
Parameters `dtype`
"""
"""
...
@@ -51,6 +54,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
...
@@ -51,6 +54,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
use_linear_projection
:
bool
=
False
use_linear_projection
:
bool
=
False
only_cross_attention
:
bool
=
False
only_cross_attention
:
bool
=
False
use_memory_efficient_attention
:
bool
=
False
use_memory_efficient_attention
:
bool
=
False
split_head_dim
:
bool
=
False
dtype
:
jnp
.
dtype
=
jnp
.
float32
dtype
:
jnp
.
dtype
=
jnp
.
float32
transformer_layers_per_block
:
int
=
1
transformer_layers_per_block
:
int
=
1
...
@@ -77,6 +81,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
...
@@ -77,6 +81,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
use_linear_projection
=
self
.
use_linear_projection
,
use_linear_projection
=
self
.
use_linear_projection
,
only_cross_attention
=
self
.
only_cross_attention
,
only_cross_attention
=
self
.
only_cross_attention
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
split_head_dim
=
self
.
split_head_dim
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
)
)
attentions
.
append
(
attn_block
)
attentions
.
append
(
attn_block
)
...
@@ -179,6 +184,9 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
...
@@ -179,6 +184,9 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
Whether to add upsampling layer before each final output
Whether to add upsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
Parameters `dtype`
"""
"""
...
@@ -192,6 +200,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
...
@@ -192,6 +200,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
use_linear_projection
:
bool
=
False
use_linear_projection
:
bool
=
False
only_cross_attention
:
bool
=
False
only_cross_attention
:
bool
=
False
use_memory_efficient_attention
:
bool
=
False
use_memory_efficient_attention
:
bool
=
False
split_head_dim
:
bool
=
False
dtype
:
jnp
.
dtype
=
jnp
.
float32
dtype
:
jnp
.
dtype
=
jnp
.
float32
transformer_layers_per_block
:
int
=
1
transformer_layers_per_block
:
int
=
1
...
@@ -219,6 +228,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
...
@@ -219,6 +228,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
use_linear_projection
=
self
.
use_linear_projection
,
use_linear_projection
=
self
.
use_linear_projection
,
only_cross_attention
=
self
.
only_cross_attention
,
only_cross_attention
=
self
.
only_cross_attention
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
split_head_dim
=
self
.
split_head_dim
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
)
)
attentions
.
append
(
attn_block
)
attentions
.
append
(
attn_block
)
...
@@ -323,6 +333,9 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
...
@@ -323,6 +333,9 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
Number of attention heads of each spatial transformer block
Number of attention heads of each spatial transformer block
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
Parameters `dtype`
"""
"""
...
@@ -332,6 +345,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
...
@@ -332,6 +345,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
num_attention_heads
:
int
=
1
num_attention_heads
:
int
=
1
use_linear_projection
:
bool
=
False
use_linear_projection
:
bool
=
False
use_memory_efficient_attention
:
bool
=
False
use_memory_efficient_attention
:
bool
=
False
split_head_dim
:
bool
=
False
dtype
:
jnp
.
dtype
=
jnp
.
float32
dtype
:
jnp
.
dtype
=
jnp
.
float32
transformer_layers_per_block
:
int
=
1
transformer_layers_per_block
:
int
=
1
...
@@ -356,6 +370,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
...
@@ -356,6 +370,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
depth
=
self
.
transformer_layers_per_block
,
depth
=
self
.
transformer_layers_per_block
,
use_linear_projection
=
self
.
use_linear_projection
,
use_linear_projection
=
self
.
use_linear_projection
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
split_head_dim
=
self
.
split_head_dim
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
)
)
attentions
.
append
(
attn_block
)
attentions
.
append
(
attn_block
)
...
...
src/diffusers/models/unet_2d_condition_flax.py
View file @
536c297a
...
@@ -92,6 +92,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -92,6 +92,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
"""
sample_size
:
int
=
32
sample_size
:
int
=
32
...
@@ -116,6 +119,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -116,6 +119,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos
:
bool
=
True
flip_sin_to_cos
:
bool
=
True
freq_shift
:
int
=
0
freq_shift
:
int
=
0
use_memory_efficient_attention
:
bool
=
False
use_memory_efficient_attention
:
bool
=
False
split_head_dim
:
bool
=
False
transformer_layers_per_block
:
Union
[
int
,
Tuple
[
int
]]
=
1
transformer_layers_per_block
:
Union
[
int
,
Tuple
[
int
]]
=
1
addition_embed_type
:
Optional
[
str
]
=
None
addition_embed_type
:
Optional
[
str
]
=
None
addition_time_embed_dim
:
Optional
[
int
]
=
None
addition_time_embed_dim
:
Optional
[
int
]
=
None
...
@@ -231,6 +235,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -231,6 +235,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
use_linear_projection
=
self
.
use_linear_projection
,
use_linear_projection
=
self
.
use_linear_projection
,
only_cross_attention
=
only_cross_attention
[
i
],
only_cross_attention
=
only_cross_attention
[
i
],
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
split_head_dim
=
self
.
split_head_dim
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
)
)
else
:
else
:
...
@@ -254,6 +259,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -254,6 +259,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
transformer_layers_per_block
=
transformer_layers_per_block
[
-
1
],
transformer_layers_per_block
=
transformer_layers_per_block
[
-
1
],
use_linear_projection
=
self
.
use_linear_projection
,
use_linear_projection
=
self
.
use_linear_projection
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
split_head_dim
=
self
.
split_head_dim
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
)
)
...
@@ -284,6 +290,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -284,6 +290,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
use_linear_projection
=
self
.
use_linear_projection
,
use_linear_projection
=
self
.
use_linear_projection
,
only_cross_attention
=
only_cross_attention
[
i
],
only_cross_attention
=
only_cross_attention
[
i
],
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
use_memory_efficient_attention
=
self
.
use_memory_efficient_attention
,
split_head_dim
=
self
.
split_head_dim
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
)
)
else
:
else
:
...
...
src/diffusers/pipelines/pipeline_flax_utils.py
View file @
536c297a
...
@@ -323,6 +323,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
...
@@ -323,6 +323,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
revision
=
kwargs
.
pop
(
"revision"
,
None
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
from_pt
=
kwargs
.
pop
(
"from_pt"
,
False
)
from_pt
=
kwargs
.
pop
(
"from_pt"
,
False
)
use_memory_efficient_attention
=
kwargs
.
pop
(
"use_memory_efficient_attention"
,
False
)
use_memory_efficient_attention
=
kwargs
.
pop
(
"use_memory_efficient_attention"
,
False
)
split_head_dim
=
kwargs
.
pop
(
"split_head_dim"
,
False
)
dtype
=
kwargs
.
pop
(
"dtype"
,
None
)
dtype
=
kwargs
.
pop
(
"dtype"
,
None
)
# 1. Download the checkpoints and configs
# 1. Download the checkpoints and configs
...
@@ -501,6 +502,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
...
@@ -501,6 +502,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
loadable_folder
,
loadable_folder
,
from_pt
=
from_pt
,
from_pt
=
from_pt
,
use_memory_efficient_attention
=
use_memory_efficient_attention
,
use_memory_efficient_attention
=
use_memory_efficient_attention
,
split_head_dim
=
split_head_dim
,
dtype
=
dtype
,
dtype
=
dtype
,
)
)
params
[
name
]
=
loaded_params
params
[
name
]
=
loaded_params
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment