Unverified Commit 536c297a authored by Pedro Cuenca, committed by GitHub

Trickle down `split_head_dim` (#5208)

parent 693a0d08
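This commit threads the `split_head_dim` flag, already accepted by `FlaxAttention`, through every Flax module that builds an attention block, so it can be set once at the top level. The flag changes the memory layout used for self-attention: instead of folding the attention heads into the batch axis, the head dimension is kept as its own axis and the attention einsums operate on that layout, which tends to compile to faster kernels on TPUs for Stable Diffusion 2.x and SDXL. A rough, illustrative sketch of the split layout follows; it mirrors the idea, not the exact code inside `FlaxAttention`, and every name in it is local to the example.

```python
# Illustrative sketch of what `split_head_dim=True` changes (not the library's exact code).
import jax.numpy as jnp
from jax import nn

def attention_with_split_head_dim(q, k, v, heads, head_dim):
    # q, k, v: (batch, seq_len, heads * head_dim), as produced by the projection layers.
    batch = q.shape[0]
    # Keep the heads as a separate axis instead of merging them into the batch axis.
    q = q.reshape(batch, -1, heads, head_dim)
    k = k.reshape(batch, -1, heads, head_dim)
    v = v.reshape(batch, -1, heads, head_dim)
    scale = head_dim ** -0.5
    # Attention scores per head: (batch, heads, query_len, key_len).
    scores = jnp.einsum("b q n h, b k n h -> b n q k", q, k) * scale
    probs = nn.softmax(scores, axis=-1)
    # Weighted sum of values, back to (batch, query_len, heads, head_dim).
    out = jnp.einsum("b n q k, b k n h -> b q n h", probs, v)
    return out.reshape(batch, -1, heads * head_dim)
```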
@@ -258,6 +258,9 @@ class FlaxBasicTransformerBlock(nn.Module):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
dim: int
n_heads: int
@@ -266,15 +269,28 @@ class FlaxBasicTransformerBlock(nn.Module):
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttention(
self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
self.dim,
self.n_heads,
self.d_head,
self.dropout,
self.use_memory_efficient_attention,
self.split_head_dim,
dtype=self.dtype,
)
# cross attention
self.attn2 = FlaxAttention(
self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
self.dim,
self.n_heads,
self.d_head,
self.dropout,
self.use_memory_efficient_attention,
self.split_head_dim,
dtype=self.dtype,
)
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -327,6 +343,9 @@ class FlaxTransformer2DModel(nn.Module):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
in_channels: int
n_heads: int
@@ -337,6 +356,7 @@ class FlaxTransformer2DModel(nn.Module):
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
@@ -362,6 +382,7 @@ class FlaxTransformer2DModel(nn.Module):
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
)
for _ in range(self.depth)
]
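Because the flag is forwarded positionally, it can also be toggled when instantiating `FlaxAttention` directly. A minimal sketch, assuming the argument order shown in the calls above and the usual `diffusers.models.attention_flax` import path (both are assumptions outside this diff):

```python
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxAttention  # assumed import path

# dim=320, heads=8, head_dim=40; the two booleans are
# use_memory_efficient_attention and split_head_dim, in that order.
attn = FlaxAttention(320, 8, 40, 0.0, False, True, dtype=jnp.float32)
hidden_states = jnp.ones((1, 64, 320))
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)  # same shape as the input: (1, 64, 320)
```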
@@ -39,6 +39,9 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
Whether to add downsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -51,6 +54,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
use_linear_projection: bool = False
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1
@@ -77,6 +81,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -179,6 +184,9 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
Whether to add upsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -192,6 +200,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
use_linear_projection: bool = False
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1
@@ -219,6 +228,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -323,6 +333,9 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
Number of attention heads of each spatial transformer block
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -332,6 +345,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
num_attention_heads: int = 1
use_linear_projection: bool = False
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1
@@ -356,6 +370,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
depth=self.transformer_layers_per_block,
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -92,6 +92,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
sample_size: int = 32
@@ -116,6 +119,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
transformer_layers_per_block: Union[int, Tuple[int]] = 1
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
@@ -231,6 +235,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
else:
@@ -254,6 +259,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
transformer_layers_per_block=transformer_layers_per_block[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
@@ -284,6 +290,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
else:
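With the field threaded through `FlaxUNet2DConditionModel`, the layout can also be selected when loading just the UNet: the pipeline loader below simply forwards `split_head_dim` to each sub-model's `from_pretrained`. A hedged sketch of doing the same by hand (the checkpoint id is only an example, not part of this commit):

```python
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# Load only the UNet with the split-head layout enabled; `from_pt=True` converts
# PyTorch weights when no Flax weights are published for the checkpoint.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # example checkpoint
    subfolder="unet",
    from_pt=True,
    split_head_dim=True,
    dtype=jnp.bfloat16,
)
```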
@@ -323,6 +323,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
revision = kwargs.pop("revision", None)
from_pt = kwargs.pop("from_pt", False)
use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
split_head_dim = kwargs.pop("split_head_dim", False)
dtype = kwargs.pop("dtype", None)
# 1. Download the checkpoints and configs
@@ -501,6 +502,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
loadable_folder,
from_pt=from_pt,
use_memory_efficient_attention=use_memory_efficient_attention,
split_head_dim=split_head_dim,
dtype=dtype,
)
params[name] = loaded_params
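At the pipeline level the new kwarg is popped once in `from_pretrained` and passed to every Flax sub-model that accepts it, so users only set it in one place. A usage sketch (the checkpoint id, `from_pt`, and dtype are illustrative assumptions, not part of this commit):

```python
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",    # example SD 2.x checkpoint
    split_head_dim=True,                   # trickles down to every Flax attention block
    use_memory_efficient_attention=False,  # independent of split_head_dim
    from_pt=True,                          # convert PyTorch weights if no Flax weights exist
    dtype=jnp.bfloat16,
)
```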