"vscode:/vscode.git/clone" did not exist on "e6a15c1a42a4792e39f7cfe99ce5c6c8ae5bbbb9"
Unverified commit 536c297a, authored by Pedro Cuenca, committed by GitHub

Trickle down `split_head_dim` (#5208)

parent 693a0d08
@@ -258,6 +258,9 @@ class FlaxBasicTransformerBlock(nn.Module):
             Parameters `dtype`
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
     """
     dim: int
     n_heads: int
@@ -266,15 +269,28 @@ class FlaxBasicTransformerBlock(nn.Module):
     only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False

     def setup(self):
         # self attention (or cross_attention if only_cross_attention is True)
         self.attn1 = FlaxAttention(
-            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+            self.dim,
+            self.n_heads,
+            self.d_head,
+            self.dropout,
+            self.use_memory_efficient_attention,
+            self.split_head_dim,
+            dtype=self.dtype,
         )
         # cross attention
         self.attn2 = FlaxAttention(
-            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+            self.dim,
+            self.n_heads,
+            self.d_head,
+            self.dropout,
+            self.use_memory_efficient_attention,
+            self.split_head_dim,
+            dtype=self.dtype,
         )
         self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -327,6 +343,9 @@ class FlaxTransformer2DModel(nn.Module):
             Parameters `dtype`
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
     """
     in_channels: int
     n_heads: int
@@ -337,6 +356,7 @@ class FlaxTransformer2DModel(nn.Module):
     only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False

     def setup(self):
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
@@ -362,6 +382,7 @@ class FlaxTransformer2DModel(nn.Module):
                 only_cross_attention=self.only_cross_attention,
                 dtype=self.dtype,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
             )
             for _ in range(self.depth)
         ]
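The call sites above pass `split_head_dim` positionally to `FlaxAttention`, right after `use_memory_efficient_attention`. The attention kernel itself is not part of this diff; the sketch below only illustrates what "splitting the head dimension into a new axis" means for the query/key/value layout, using hypothetical helper names and shapes rather than the actual `FlaxAttention` implementation.

```python
# Minimal sketch (not the actual FlaxAttention code): contrast the default
# "flatten heads into the batch axis" layout with a split-head-dim layout.
import jax
import jax.numpy as jnp


def attention_flat_heads(q, k, v, heads):
    # q, k, v: (batch, seq, heads * d_head)
    b, s, inner = q.shape
    d_head = inner // heads

    def to_flat(x):
        # reshape to (batch * heads, seq, d_head) before the matmuls
        x = x.reshape(b, s, heads, d_head).transpose(0, 2, 1, 3)
        return x.reshape(b * heads, s, d_head)

    q, k, v = map(to_flat, (q, k, v))
    attn = jax.nn.softmax(jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(d_head), axis=-1)
    out = jnp.einsum("bqk,bkd->bqd", attn, v)
    # undo the flattening to get back to (batch, seq, heads * d_head)
    out = out.reshape(b, heads, s, d_head).transpose(0, 2, 1, 3)
    return out.reshape(b, s, inner)


def attention_split_head_dim(q, k, v, heads):
    # q, k, v: (batch, seq, heads * d_head)
    b, s, inner = q.shape
    d_head = inner // heads
    # keep heads as their own axis: (batch, seq, heads, d_head)
    q, k, v = (x.reshape(b, s, heads, d_head) for x in (q, k, v))
    attn = jax.nn.softmax(jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(d_head), axis=-1)
    out = jnp.einsum("bhqk,bkhd->bqhd", attn, v)
    return out.reshape(b, s, inner)
```

Both variants compute the same attention; keeping the head axis separate skips the transpose/reshape round-trip through a `(batch * heads, seq, d_head)` layout, which is where the speed-up for Stable Diffusion 2.x and SDXL shapes is expected to come from.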
@@ -39,6 +39,9 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
             Whether to add downsampling layer before each final output
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -51,6 +54,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     use_linear_projection: bool = False
     only_cross_attention: bool = False
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False
     dtype: jnp.dtype = jnp.float32
     transformer_layers_per_block: int = 1

@@ -77,6 +81,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -179,6 +184,9 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
             Whether to add upsampling layer before each final output
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -192,6 +200,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     use_linear_projection: bool = False
     only_cross_attention: bool = False
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False
     dtype: jnp.dtype = jnp.float32
     transformer_layers_per_block: int = 1

@@ -219,6 +228,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -323,6 +333,9 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
             Number of attention heads of each spatial transformer block
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -332,6 +345,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     num_attention_heads: int = 1
     use_linear_projection: bool = False
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False
     dtype: jnp.dtype = jnp.float32
     transformer_layers_per_block: int = 1

@@ -356,6 +370,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
                 depth=self.transformer_layers_per_block,
                 use_linear_projection=self.use_linear_projection,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -92,6 +92,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
     """

     sample_size: int = 32
@@ -116,6 +119,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     flip_sin_to_cos: bool = True
     freq_shift: int = 0
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False
     transformer_layers_per_block: Union[int, Tuple[int]] = 1
     addition_embed_type: Optional[str] = None
     addition_time_embed_dim: Optional[int] = None
@@ -231,6 +235,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                     use_linear_projection=self.use_linear_projection,
                     only_cross_attention=only_cross_attention[i],
                     use_memory_efficient_attention=self.use_memory_efficient_attention,
+                    split_head_dim=self.split_head_dim,
                     dtype=self.dtype,
                 )
             else:
@@ -254,6 +259,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             transformer_layers_per_block=transformer_layers_per_block[-1],
             use_linear_projection=self.use_linear_projection,
             use_memory_efficient_attention=self.use_memory_efficient_attention,
+            split_head_dim=self.split_head_dim,
             dtype=self.dtype,
         )

@@ -284,6 +290,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                     use_linear_projection=self.use_linear_projection,
                     only_cross_attention=only_cross_attention[i],
                     use_memory_efficient_attention=self.use_memory_efficient_attention,
+                    split_head_dim=self.split_head_dim,
                     dtype=self.dtype,
                 )
             else:
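`split_head_dim` is now a regular config field on `FlaxUNet2DConditionModel`, so it can be switched on when the model is loaded, in the same way `use_memory_efficient_attention` already is. A minimal, hedged sketch follows; the checkpoint id is only an example, and `from_pt=True` may be needed when a repository ships only PyTorch weights.

```python
# Sketch: enable split_head_dim when loading just the UNet.
# The checkpoint id below is illustrative, not part of this PR.
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # example repo
    subfolder="unet",
    split_head_dim=True,  # flag trickled down by this PR
    dtype=jnp.bfloat16,
)
```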
@@ -323,6 +323,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
         revision = kwargs.pop("revision", None)
         from_pt = kwargs.pop("from_pt", False)
         use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
+        split_head_dim = kwargs.pop("split_head_dim", False)
         dtype = kwargs.pop("dtype", None)

         # 1. Download the checkpoints and configs
@@ -501,6 +502,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
                     loadable_folder,
                     from_pt=from_pt,
                     use_memory_efficient_attention=use_memory_efficient_attention,
+                    split_head_dim=split_head_dim,
                     dtype=dtype,
                 )
                 params[name] = loaded_params
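At the pipeline level the new flag is simply popped from `kwargs` and forwarded to the sub-models that accept it, so callers only need to pass it to `from_pretrained`. A hedged usage sketch follows; the pipeline class and checkpoint are examples of typical use rather than part of this diff, and `from_pt=True` may be required when no Flax weights are published for the checkpoint.

```python
# Sketch: pass split_head_dim straight through pipeline loading.
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",  # example SD 2.x checkpoint
    split_head_dim=True,                 # popped and forwarded as shown above
    dtype=jnp.bfloat16,
)
```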