Unverified commit c8bb1ff5 authored by Quentin Gallouédec, committed by GitHub

Use HF Papers (#11567)



* Use HF Papers

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 799adf4a
@@ -83,8 +83,8 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
 model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
 diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
-however, no such scaling factor was used, hence the value of 1.0 as the default.
+Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. For this
+Autoencoder, however, no such scaling factor was used, hence the value of 1.0 as the default.
 force_upcast (`bool`, *optional*, default to `False`):
 If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
 can be fine-tuned / trained to a lower range without losing too much precision, in which case
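For context, the `scaling_factor` convention described in this docstring can be mirrored in a few lines. A minimal sketch in plain PyTorch; the random tensor stands in for real VAE output and is not diffusers API:

```python
import torch

scaling_factor = 0.18215  # value used by the SD VAE; AutoencoderTiny defaults to 1.0 as noted above

latents = torch.randn(1, 4, 64, 64)  # stand-in for a VAE-encoded image

# Before the diffusion model: z = z * scaling_factor
latents_for_unet = latents * scaling_factor

# Before decoding: z = 1 / scaling_factor * z
latents_for_decoder = latents_for_unet / scaling_factor
```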
@@ -66,7 +66,7 @@ class VQModel(ModelMixin, ConfigMixin):
 model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
 diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
 norm_type (`str`, *optional*, defaults to `"group"`):
 Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
 """
@@ -63,8 +63,8 @@ class ControlNetOutput(BaseOutput):
 class ControlNetConditioningEmbedding(nn.Module):
 """
-Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
-[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+Quoting from https://huggingface.co/papers/2302.05543: "Stable Diffusion uses a pre-processing method similar to
+VQ-GAN [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
 training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
 convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
 (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
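The quoted description of E(·) is concrete enough to sketch directly. A hypothetical, simplified encoder that follows the quote literally (four 4 × 4 convs with stride 2, channels 16/32/64/128, ReLU); note that four stride-2 convs reduce 512 → 32, so shipped implementations arrange strides slightly differently to land exactly on the 64 × 64 latent resolution:

```python
import torch
import torch.nn as nn

class TinyConditionEncoder(nn.Module):
    """Illustrative sketch of the E(.) network quoted above; not the
    diffusers ControlNetConditioningEmbedding."""

    def __init__(self, in_channels: int = 3):
        super().__init__()
        layers, prev = [], in_channels
        for ch in (16, 32, 64, 128):  # channel progression from the quote
            layers += [nn.Conv2d(prev, ch, kernel_size=4, stride=2, padding=1), nn.ReLU()]
            prev = ch
        self.net = nn.Sequential(*layers)

    def forward(self, cond: torch.Tensor) -> torch.Tensor:
        return self.net(cond)

print(TinyConditionEncoder()(torch.randn(1, 3, 512, 512)).shape)  # torch.Size([1, 128, 32, 32])
```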
@@ -103,7 +103,7 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
 activation_fn=activation_fn,
 ff_inner_dim=int(self.inner_dim * mlp_ratio),
 cross_attention_dim=cross_attention_dim,
-qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+qk_norm=True,  # See https://huggingface.co/papers/2302.05442 for details.
 skip=False,  # always False as it is the first half of the model
 )
 for layer in range(transformer_num_layers // 2 - 1)
@@ -96,7 +96,7 @@ class SparseControlNetConditioningEmbedding(nn.Module):
 class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 """
 A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
-Models](https://arxiv.org/abs/2311.16933).
+Models](https://huggingface.co/papers/2311.16933).
 Args:
 in_channels (`int`, defaults to 4):
@@ -942,7 +942,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
 # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
 def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
-r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
 The suffixes after the scaling factors represent the stage blocks where they are being applied.
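In practice `enable_freeu` is usually reached through a pipeline rather than called on the model directly. A usage sketch; the checkpoint name and the scaling values are illustrative placeholders, not tuned recommendations:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# b1/b2 amplify backbone features, s1/s2 attenuate skip features (see the FreeU paper).
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)

# Revert to the vanilla UNet behavior when done.
pipe.disable_freeu()
```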
@@ -1401,7 +1401,7 @@ class ImagePositionalEmbeddings(nn.Module):
 Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
 height and width of the latent space.
-For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+For more details, see figure 10 of the dall-e paper: https://huggingface.co/papers/2102.12092
 For VQ-diffusion:
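The summed class-plus-axial-position scheme described here is easy to sketch. A hypothetical, condensed version (the class name and sizes are illustrative, not the diffusers module):

```python
import torch
import torch.nn as nn

class ClassPlusAxialPositionEmbedding(nn.Module):
    """Embed discrete latent-image classes, then add separate learned
    height and width position embeddings, as the docstring describes."""

    def __init__(self, num_embed: int, height: int, width: int, dim: int):
        super().__init__()
        self.height, self.width = height, width
        self.emb = nn.Embedding(num_embed, dim)
        self.height_emb = nn.Embedding(height, dim)
        self.width_emb = nn.Embedding(width, dim)

    def forward(self, index: torch.Tensor) -> torch.Tensor:  # (batch, height * width) of class ids
        x = self.emb(index)
        h = self.height_emb(torch.arange(self.height, device=index.device))  # (H, dim)
        w = self.width_emb(torch.arange(self.width, device=index.device))    # (W, dim)
        pos = (h[:, None, :] + w[None, :, :]).reshape(1, self.height * self.width, -1)
        return x + pos

emb = ClassPlusAxialPositionEmbedding(num_embed=8192, height=32, width=32, dim=512)
print(emb(torch.randint(0, 8192, (2, 32 * 32))).shape)  # torch.Size([2, 1024, 512])
```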
@@ -89,7 +89,7 @@ class FlaxTimestepEmbedding(nn.Module):
 class FlaxTimesteps(nn.Module):
 r"""
-Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
+Wrapper Module for sinusoidal Time step Embeddings as described in https://huggingface.co/papers/2006.11239
 Args:
 dim (`int`, *optional*, defaults to `32`):
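The DDPM-style sinusoidal embedding this wrapper refers to fits in a few lines. A simplified sketch that omits the `flip_sin_to_cos` and `freq_shift` options the diffusers modules expose:

```python
import math

import torch

def sinusoidal_timestep_embedding(timesteps: torch.Tensor, dim: int = 32) -> torch.Tensor:
    # Geometric frequency ladder from 1 down to 1/10000, as in Transformer
    # positional encodings, applied to the diffusion timestep.
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

print(sinusoidal_timestep_embedding(torch.tensor([0, 10, 999])).shape)  # torch.Size([3, 32])
```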
@@ -237,7 +237,7 @@ class AdaLayerNormSingle(nn.Module):
 r"""
 Norm layer adaptive layer norm single (adaLN-single).
-As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3).
 Parameters:
 embedding_dim (`int`): The size of each embedding vector.
@@ -510,7 +510,7 @@ else:
 class RMSNorm(nn.Module):
 r"""
-RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
+RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
 Args:
 dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
@@ -600,7 +600,7 @@ class MochiRMSNorm(nn.Module):
 class GlobalResponseNorm(nn.Module):
 r"""
-Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808).
+Global response normalization as introduced in ConvNeXt-v2 (https://huggingface.co/papers/2301.00808).
 Args:
 dim (`int`): Number of dimensions to use for the `gamma` and `beta`.
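The RMSNorm referenced in this hunk reduces to one line of math: scale by the root mean square of the features, with no mean subtraction and no bias. A minimal sketch (not the diffusers class, which adds `elementwise_affine` handling):

```python
import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), then a learned per-feature gain.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight

print(SimpleRMSNorm(64)(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])
```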
@@ -359,7 +359,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
 self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim)
 self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
-# https://arxiv.org/abs/2309.16588
+# https://huggingface.co/papers/2309.16588
 # prevents artifacts in the attention maps
 self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02)
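The `register_tokens` parameter above implements the register-token idea from the linked paper: a few learned tokens are prepended to the sequence and later discarded. A hypothetical sketch of that mechanism:

```python
import torch
import torch.nn as nn

class WithRegisters(nn.Module):
    """Prepend learned 'register' tokens before the transformer blocks;
    they absorb global information and are dropped at the output."""

    def __init__(self, dim: int, num_registers: int = 8):
        super().__init__()
        self.register_tokens = nn.Parameter(torch.randn(1, num_registers, dim) * 0.02)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        regs = self.register_tokens.expand(tokens.shape[0], -1, -1)
        return torch.cat([regs, tokens], dim=1)  # run the transformer on this, then slice the registers off

print(WithRegisters(1024)(torch.randn(2, 256, 1024)).shape)  # torch.Size([2, 264, 1024])
```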
@@ -30,7 +30,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 class DiTTransformer2DModel(ModelMixin, ConfigMixin):
 r"""
-A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).
+A 2D Transformer model as introduced in DiT (https://huggingface.co/papers/2212.09748).
 Parameters:
 num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
@@ -308,7 +308,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
 activation_fn=activation_fn,
 ff_inner_dim=int(self.inner_dim * mlp_ratio),
 cross_attention_dim=cross_attention_dim,
-qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+qk_norm=True,  # See https://huggingface.co/papers/2302.05442 for details.
 skip=layer > num_layers // 2,
 )
 for layer in range(num_layers)
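The `qk_norm=True` flag refers to query-key normalization: normalizing queries and keys per head before the dot product keeps attention logits bounded at scale. A hypothetical sketch of the mechanism (the diffusers blocks wire this through their attention processors):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def qk_norm_attention(q, k, v, norm_q: nn.Module, norm_k: nn.Module):
    # Normalize over the head dimension before computing attention scores.
    return F.scaled_dot_product_attention(norm_q(q), norm_k(k), v)

heads, head_dim = 8, 64
q, k, v = (torch.randn(2, heads, 128, head_dim) for _ in range(3))
out = qk_norm_attention(q, k, v, nn.LayerNorm(head_dim), nn.LayerNorm(head_dim))
print(out.shape)  # torch.Size([2, 8, 128, 64])
```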
@@ -30,7 +30,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
 _supports_gradient_checkpointing = True
 """
-A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, official code:
+A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code:
 https://github.com/Vchitect/Latte
 Parameters:
@@ -31,8 +31,8 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
 r"""
-A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
-https://arxiv.org/abs/2403.04692).
+A 2D Transformer model as introduced in PixArt family of models (https://huggingface.co/papers/2310.00426,
+https://huggingface.co/papers/2403.04692).
 Parameters:
 num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
@@ -61,7 +61,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
 added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
 Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
 product between the text embedding and image embedding as proposed in the unclip paper
-https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
+https://huggingface.co/papers/2204.06125 If it is `None`, no additional embeddings will be prepended.
 time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
 If None, will be set to `num_attention_heads * attention_head_dim`
 embedding_proj_dim (`int`, *optional*, default to None):
@@ -390,7 +390,7 @@ class T5LayerNorm(nn.Module):
 def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
-# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
+# Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
 # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
 # half-precision inputs is done in fp32
@@ -407,7 +407,7 @@ class T5LayerNorm(nn.Module):
 class NewGELUActivation(nn.Module):
 """
 Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
-the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
 """
 def forward(self, input: torch.Tensor) -> torch.Tensor:
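For reference, the "new GELU" this class implements is the tanh approximation from the linked paper. A standalone version matching that formula:

```python
import math

import torch

def new_gelu(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

print(new_gelu(torch.linspace(-3.0, 3.0, 7)))
```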
@@ -283,7 +283,7 @@ class OmniGenBlock(nn.Module):
 class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
 """
-The Transformer model introduced in OmniGen (https://arxiv.org/pdf/2409.11340).
+The Transformer model introduced in OmniGen (https://huggingface.co/papers/2409.11340).
 Parameters:
 in_channels (`int`, defaults to `4`):
@@ -22,7 +22,7 @@ from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
 class FlaxCrossAttnDownBlock2D(nn.Module):
 r"""
 Cross Attention 2D Downsizing block - original architecture from Unet transformers:
-https://arxiv.org/abs/2103.06104
+https://huggingface.co/papers/2103.06104
 Parameters:
 in_channels (:obj:`int`):
@@ -38,7 +38,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
 add_downsample (:obj:`bool`, *optional*, defaults to `True`):
 Whether to add downsampling layer before each final output
 use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
-enable memory efficient attention https://arxiv.org/abs/2112.05682
+enable memory efficient attention https://huggingface.co/papers/2112.05682
 split_head_dim (`bool`, *optional*, defaults to `False`):
 Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
 enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
@@ -169,7 +169,7 @@ class FlaxDownBlock2D(nn.Module):
 class FlaxCrossAttnUpBlock2D(nn.Module):
 r"""
 Cross Attention 2D Upsampling block - original architecture from Unet transformers:
-https://arxiv.org/abs/2103.06104
+https://huggingface.co/papers/2103.06104
 Parameters:
 in_channels (:obj:`int`):
@@ -185,7 +185,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
 add_upsample (:obj:`bool`, *optional*, defaults to `True`):
 Whether to add upsampling layer before each final output
 use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
-enable memory efficient attention https://arxiv.org/abs/2112.05682
+enable memory efficient attention https://huggingface.co/papers/2112.05682
 split_head_dim (`bool`, *optional*, defaults to `False`):
 Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
 enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
@@ -324,7 +324,8 @@ class FlaxUpBlock2D(nn.Module):
 class FlaxUNetMidBlock2DCrossAttn(nn.Module):
 r"""
-Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
+Cross Attention 2D Mid-level block - original architecture from Unet transformers:
+https://huggingface.co/papers/2103.06104
 Parameters:
 in_channels (:obj:`int`):
@@ -336,7 +337,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
 num_attention_heads (:obj:`int`, *optional*, defaults to 1):
 Number of attention heads of each spatial transformer block
 use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
-enable memory efficient attention https://arxiv.org/abs/2112.05682
+enable memory efficient attention https://huggingface.co/papers/2112.05682
 split_head_dim (`bool`, *optional*, defaults to `False`):
 Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
 enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
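The `use_memory_efficient_attention` flag points at the "self-attention does not need O(n²) memory" result. A simplified sketch of the core idea, chunking queries so only a small block of attention scores is live at once (the paper's full algorithm also chunks keys/values with a running softmax):

```python
import torch
import torch.nn.functional as F

def query_chunked_attention(q, k, v, chunk: int = 64):
    # Each query row attends independently, so chunking queries is exact,
    # not an approximation; peak score memory drops from O(n^2) to O(n * chunk).
    outs = [
        F.scaled_dot_product_attention(q[..., i : i + chunk, :], k, v)
        for i in range(0, q.shape[-2], chunk)
    ]
    return torch.cat(outs, dim=-2)

q = k = v = torch.randn(1, 8, 1024, 64)
print(query_chunked_attention(q, k, v).shape)  # torch.Size([1, 8, 1024, 64])
```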
@@ -835,7 +835,7 @@ class UNet2DConditionModel(
 fn_recursive_set_attention_slice(module, reversed_slice_size)
 def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
-r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
 The suffixes after the scaling factors represent the stage blocks where they are being applied.
@@ -94,7 +94,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
 Whether to flip the sin to cos in the time embedding.
 freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
 use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
-Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
+Enable memory efficient attention as described [here](https://huggingface.co/papers/2112.05682).
 split_head_dim (`bool`, *optional*, defaults to `False`):
 Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
 enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.