Unverified Commit 4a343077 authored by Sayak Paul, committed by GitHub

add: utility to format our docs too 📜 (#7314)

* add: utility to format our docs too 📜

* debugging saga

* fix: message

* checking

* should be fixed.

* revert pipeline_fixture

* remove empty line

* make style

* fix: setup.py

* style.
parent 8e963d1c
@@ -32,9 +32,7 @@ jobs:
           python -m pip install --upgrade pip
           pip install .[quality]
       - name: Check quality
-        run: |
-          ruff check examples tests src utils scripts
-          ruff format examples tests src utils scripts --check
+        run: make quality
       - name: Check if failure
         if: ${{ failure() }}
         run: |
@@ -53,7 +51,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           pip install .[quality]
-      - name: Check quality
+      - name: Check repo consistency
         run: |
           python utils/check_copies.py
           python utils/check_dummies.py
...
@@ -40,9 +40,7 @@ jobs:
           python -m pip install --upgrade pip
           pip install .[quality]
       - name: Check quality
-        run: |
-          ruff check examples tests src utils scripts
-          ruff format examples tests src utils scripts --check
+        run: make quality
       - name: Check if failure
         if: ${{ failure() }}
         run: |
@@ -61,7 +59,7 @@ jobs:
         run: |
          python -m pip install --upgrade pip
           pip install .[quality]
-      - name: Check quality
+      - name: Check repo consistency
         run: |
           python utils/check_copies.py
           python utils/check_dummies.py
...
@@ -42,6 +42,7 @@ repo-consistency:
 quality:
     ruff check $(check_dirs) setup.py
     ruff format --check $(check_dirs) setup.py
+    doc-builder style src/diffusers docs/source --max_len 119 --check_only
     python utils/check_doc_toc.py
 
 # Format source code automatically and check is there are any problems left that need manual fixing
@@ -55,6 +56,7 @@ extra_style_checks:
 style:
     ruff check $(check_dirs) setup.py --fix
     ruff format $(check_dirs) setup.py
+    doc-builder style src/diffusers docs/source --max_len 119
     ${MAKE} autogenerate_code
     ${MAKE} extra_style_checks
...
@@ -134,6 +134,7 @@ _deps = [
     "torchvision",
     "transformers>=4.25.1",
     "urllib3<=2.0.0",
+    "black",
 ]
 
 # this is a lookup table with items like:
...
@@ -42,4 +42,5 @@ deps = {
     "torchvision": "torchvision",
     "transformers": "transformers>=4.25.1",
     "urllib3": "urllib3<=2.0.0",
+    "black": "black",
 }
@@ -173,8 +173,9 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
         """
-        Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
-        for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
+        Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect
+        ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
+        processing are 512x512, the region will be expanded to 128x128.
 
         Args:
             mask_image (PIL.Image.Image): Mask image.
@@ -183,7 +184,8 @@ class VaeImageProcessor(ConfigMixin):
             pad (int, optional): Padding to be added to the crop region. Defaults to 0.
 
         Returns:
-            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
+            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and
+            matches the original aspect ratio.
         """
         mask_image = mask_image.convert("L")
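For reference, a minimal sketch of how `get_crop_region` is typically called, assuming a 512x512 PIL mask with a 128x32 painted strip (the mask contents are made up for illustration):

```python
# Sketch only: the mask below is synthetic, built just to exercise get_crop_region.
import PIL.Image
from diffusers.image_processor import VaeImageProcessor

mask = PIL.Image.new("L", (512, 512), 0)
mask.paste(255, (100, 200, 228, 232))  # a 128x32 masked strip inside the 512x512 canvas

# Returns (x1, y1, x2, y2); the box is grown to match the 1:1 aspect ratio of the image.
x1, y1, x2, y2 = VaeImageProcessor.get_crop_region(mask, width=512, height=512, pad=16)
print((x1, y1, x2, y2))
```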
@@ -265,7 +267,8 @@ class VaeImageProcessor(ConfigMixin):
         height: int,
     ) -> PIL.Image.Image:
         """
-        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+        the image within the dimensions, filling empty with data from image.
 
         Args:
             image: The image to resize.
@@ -309,7 +312,8 @@ class VaeImageProcessor(ConfigMixin):
         height: int,
     ) -> PIL.Image.Image:
         """
-        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+        the image within the dimensions, cropping the excess.
 
         Args:
             image: The image to resize.
@@ -346,12 +350,12 @@ class VaeImageProcessor(ConfigMixin):
                 The width to resize to.
             resize_mode (`str`, *optional*, defaults to `default`):
                 The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
-                within the specified width and height, and it may not maintaining the original aspect ratio.
-                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, filling empty with data from image.
-                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, cropping the excess.
-                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+                within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
+                will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
+                then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
+                the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+                the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
+                supported for PIL image input.
 
         Returns:
             `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
@@ -456,19 +460,21 @@ class VaeImageProcessor(ConfigMixin):
         Args:
             image (`pipeline_image_input`):
-                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
+                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
+                supported formats.
             height (`int`, *optional*, defaults to `None`):
-                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
+                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
+                height.
             width (`int`, *optional*`, defaults to `None`):
                 The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
             resize_mode (`str`, *optional*, defaults to `default`):
-                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
-                within the specified width and height, and it may not maintaining the original aspect ratio.
-                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, filling empty with data from image.
-                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, cropping the excess.
-                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
+                the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
+                resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
+                center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
+                image to fit within the specified width and height, maintaining the aspect ratio, and then center the
+                image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
+                supported for PIL image input.
             crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                 The crop coordinates for each image in the batch. If `None`, will not crop the image.
         """
@@ -930,8 +936,8 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
     @staticmethod
     def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
         """
-        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
-        If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
+        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
+        aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
 
         Args:
             mask (`torch.FloatTensor`):
...
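The `resize_mode` options documented in the rewrapped `preprocess` docstring above can be exercised roughly as follows; a sketch, with a hypothetical input file and target dimensions chosen as multiples of the VAE scale factor:

```python
# Sketch only: "fill" and "crop" are the PIL-only resize modes described above.
import PIL.Image
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
image = PIL.Image.open("input.png").convert("RGB")  # hypothetical input image

# "fill": keep the aspect ratio, then pad out to the target size with data from the image.
filled = processor.preprocess(image, height=512, width=768, resize_mode="fill")

# "crop": keep the aspect ratio, then center-crop the excess to reach the target size.
cropped = processor.preprocess(image, height=512, width=768, resize_mode="crop")

print(filled.shape, cropped.shape)  # normalized torch tensors of shape (1, 3, 512, 768)
```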
@@ -67,17 +67,18 @@ class IPAdapterMixin:
                     - A [torch state
                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
             subfolder (`str` or `List[str]`):
-                The subfolder location of a model file within a larger model repository on the Hub or locally.
-                If a list is passed, it should have the same length as `weight_name`.
+                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+                list is passed, it should have the same length as `weight_name`.
             weight_name (`str` or `List[str]`):
                 The name of the weight file to load. If a list is passed, it should have the same length as
                 `weight_name`.
             image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
                 The subfolder location of the image encoder within a larger model repository on the Hub or locally.
-                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
-                you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
-                If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
-                for example, `image_encoder_folder="different_subfolder/image_encoder"`.
+                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
+                `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
+                `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
+                `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
+                `image_encoder_folder="different_subfolder/image_encoder"`.
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
...
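As a reference for the `subfolder`, `weight_name`, and `image_encoder_folder` arguments described above, a typical call might look like this sketch (the repository layout follows the commonly used `h94/IP-Adapter` checkpoints and is an assumption, not part of this commit):

```python
# Sketch: loading an IP-Adapter with the arguments documented above.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
    # image_encoder_folder defaults to "image_encoder"; pass None to skip loading it.
)
```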
@@ -20,7 +20,8 @@ from ..utils import MIN_PEFT_VERSION, check_peft_version, is_peft_available
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
-    more details about adapters and injecting them in a transformer-based model, check out the PEFT [documentation](https://huggingface.co/docs/peft/index).
+    more details about adapters and injecting them in a transformer-based model, check out the PEFT
+    [documentation](https://huggingface.co/docs/peft/index).
 
     Install the latest version of PEFT, and use this mixin to:
@@ -143,8 +144,8 @@ class PeftAdapterMixin:
     def enable_adapters(self) -> None:
         """
-        Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the
-        list of adapters to enable.
+        Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
+        adapters to enable.
 
         If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
         [documentation](https://huggingface.co/docs/peft).
...
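A sketch of the adapter toggling that `enable_adapters` describes, run against a pipeline's UNet (which mixes in `PeftAdapterMixin`); the LoRA repository and adapter name are illustrative assumptions:

```python
# Sketch only: the LoRA repo id and adapter name below are placeholders.
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.load_lora_weights("some-user/some-sdxl-lora", adapter_name="style")  # placeholder repo

unet = pipe.unet
unet.disable_adapters()        # temporarily bypass every attached adapter
unet.enable_adapters()         # re-enable; the list comes from unet.active_adapters()
print(unet.active_adapters())  # e.g. ["style"]
```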
@@ -198,19 +198,24 @@ class FromSingleFileMixin:
             model_type (`str`, *optional*):
                 The type of model to load. If not provided, the model type will be inferred from the checkpoint file.
             image_size (`int`, *optional*):
-                The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE model.
+                The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE
+                model.
             load_safety_checker (`bool`, *optional*, defaults to `False`):
-                Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a `safety_checker` component is passed to the `kwargs`.
+                Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a
+                `safety_checker` component is passed to the `kwargs`.
             num_in_channels (`int`, *optional*):
-                Specify the number of input channels for the UNet model. Read more about how to configure UNet model with this parameter
+                Specify the number of input channels for the UNet model. Read more about how to configure UNet model
+                with this parameter
                 [here](https://huggingface.co/docs/diffusers/training/adapt_a_model#configure-unet2dconditionmodel-parameters).
             scaling_factor (`float`, *optional*):
-                The scaling factor to use for the VAE model. If not provided, it is inferred from the config file first.
-                If the scaling factor is not found in the config file, the default value 0.18215 is used.
+                The scaling factor to use for the VAE model. If not provided, it is inferred from the config file
+                first. If the scaling factor is not found in the config file, the default value 0.18215 is used.
             scheduler_type (`str`, *optional*):
-                The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint file.
+                The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint
+                file.
             prediction_type (`str`, *optional*):
-                The type of prediction to load. If not provided, the prediction type will be inferred from the checkpoint file.
+                The type of prediction to load. If not provided, the prediction type will be inferred from the
+                checkpoint file.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                 class). The overwritten components are passed directly to the pipelines `__init__` method. See example
...
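For orientation, the overrides documented above are passed straight through `from_single_file`; a sketch with a placeholder checkpoint path:

```python
# Sketch only: the checkpoint path is a placeholder, not a file referenced by this commit.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file(
    "path/to/model.safetensors",  # placeholder single-file checkpoint
    num_in_channels=9,            # e.g. force a 9-channel (inpainting-style) UNet input
    scaling_factor=0.18215,       # the documented fallback when the config has no value
    prediction_type="epsilon",
    load_safety_checker=False,
)
```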
@@ -487,20 +487,35 @@ class TextualInversionLoaderMixin:
         # Example 3: unload from SDXL
         pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
-        embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model")
+        embedding_path = hf_hub_download(
+            repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model"
+        )
 
         # load embeddings to the text encoders
         state_dict = load_file(embedding_path)
 
         # load embeddings of text_encoder 1 (CLIP ViT-L/14)
-        pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+        pipeline.load_textual_inversion(
+            state_dict["clip_l"],
+            token=["<s0>", "<s1>"],
+            text_encoder=pipeline.text_encoder,
+            tokenizer=pipeline.tokenizer,
+        )
 
         # load embeddings of text_encoder 2 (CLIP ViT-G/14)
-        pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
+        pipeline.load_textual_inversion(
+            state_dict["clip_g"],
+            token=["<s0>", "<s1>"],
+            text_encoder=pipeline.text_encoder_2,
+            tokenizer=pipeline.tokenizer_2,
+        )
 
         # Unload explicitly from both text encoders abd tokenizers
-        pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
-        pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
+        pipeline.unload_textual_inversion(
+            tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
+        )
+        pipeline.unload_textual_inversion(
+            tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
+        )
         ```
         """
...
@@ -74,37 +74,24 @@ def _maybe_expand_lora_scales_for_one_adapter(
     E.g. turns
     ```python
-    scales = {
-        'down': 2,
-        'mid': 3,
-        'up': {
-            'block_0': 4,
-            'block_1': [5, 6, 7]
-        }
-    }
-    blocks_with_transformer = {
-        'down': [1,2],
-        'up': [0,1]
-    }
-    transformer_per_block = {
-        'down': 2,
-        'up': 3
-    }
+    scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
+    blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
+    transformer_per_block = {"down": 2, "up": 3}
     ```
     into
     ```python
     {
-        'down.block_1.0': 2,
-        'down.block_1.1': 2,
-        'down.block_2.0': 2,
-        'down.block_2.1': 2,
-        'mid': 3,
-        'up.block_0.0': 4,
-        'up.block_0.1': 4,
-        'up.block_0.2': 4,
-        'up.block_1.0': 5,
-        'up.block_1.1': 6,
-        'up.block_1.2': 7,
+        "down.block_1.0": 2,
+        "down.block_1.1": 2,
+        "down.block_2.0": 2,
+        "down.block_2.1": 2,
+        "mid": 3,
+        "up.block_0.0": 4,
+        "up.block_0.1": 4,
+        "up.block_0.2": 4,
+        "up.block_1.0": 5,
+        "up.block_1.1": 6,
+        "up.block_1.2": 7,
     }
     ```
     """
...
@@ -1298,9 +1298,9 @@ class AttnProcessor2_0:
 class FusedAttnProcessor2_0:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
-    key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
+    fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
+    For cross-attention modules, key and value projection matrices are fused.
 
     <Tip warning={true}>
...
@@ -453,8 +453,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
...
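This `AutoencoderKL.fuse_qkv_projections` helper and the `UNet2DConditionModel` counterpart further down expose the same switch; a sketch of how it is typically flipped on a loaded pipeline (the checkpoint is illustrative, and PyTorch 2.0 is assumed for scaled dot-product attention):

```python
# Sketch: enabling fused QKV projections on the UNet and VAE of a loaded pipeline.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.unet.fuse_qkv_projections()  # self-attention: fuse q/k/v; cross-attention: fuse k/v
pipe.vae.fuse_qkv_projections()
image = pipe("an astronaut riding a horse").images[0]
```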
@@ -329,15 +329,15 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
             controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
             conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
-                plain tuple.
+                Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of
+                a plain tuple.
             train (`bool`, *optional*, defaults to `False`):
                 Use deterministic functions and disable dropout when not training.
 
         Returns:
             [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
-                [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
-                `tuple`. When returning a tuple, the first element is the sample tensor.
+                [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise
+                a `tuple`. When returning a tuple, the first element is the sample tensor.
         """
         channel_order = self.controlnet_conditioning_channel_order
         if channel_order == "bgr":
...
@@ -795,16 +795,13 @@ class IPAdapterPlusImageProjection(nn.Module):
     Args:
     ----
-    embed_dims (int): The feature dimension. Defaults to 768.
-    output_dims (int): The number of output channels, that is the same
-        number of the channels in the
-        `unet.config.cross_attention_dim`. Defaults to 1024.
-    hidden_dims (int): The number of hidden channels. Defaults to 1280.
-    depth (int): The number of blocks. Defaults to 8.
-    dim_head (int): The number of head channels. Defaults to 64.
-    heads (int): Parallel attention heads. Defaults to 16.
-    num_queries (int): The number of queries. Defaults to 8.
-    ffn_ratio (float): The expansion ratio of feedforward network hidden
+    embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+    that is the same
+        number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+    hidden_dims (int): The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+    to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+    Defaults to 16. num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
+    of feedforward network hidden
         layer channels. Defaults to 4.
     """
...
@@ -202,8 +202,8 @@ class ResnetBlock2D(nn.Module):
         eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
         non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
         time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
-            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift"
-            for a stronger conditioning with scale and shift.
+            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
+            stronger conditioning with scale and shift.
         kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
             [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
         output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
...
@@ -120,7 +120,8 @@ class DualTransformer2DModel(nn.Module):
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+                Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
 
         Returns:
             [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
...
@@ -294,8 +294,8 @@ class TransformerSpatioTemporalModel(nn.Module):
                 A tensor indicating whether the input contains only images. 1 indicates that the input contains only
                 images, 0 indicates that the input contains video frames.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
-                tuple.
+                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
+                plain tuple.
 
         Returns:
             [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
...
@@ -865,8 +865,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
@@ -1093,8 +1093,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         Returns:
             [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
-                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
-                a `tuple` is returned where the first element is the sample tensor.
+                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
...
@@ -76,7 +76,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
             The tuple of upsample blocks to use.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
-            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
+            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer
+            is skipped.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
@@ -350,15 +351,15 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             mid_block_additional_residual: (`torch.Tensor`, *optional*):
                 A tensor that if specified is added to the residual of the middle unet block.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
-                plain tuple.
+                Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of
+                a plain tuple.
             train (`bool`, *optional*, defaults to `False`):
                 Use deterministic functions and disable dropout when not training.
 
         Returns:
             [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
-                [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
-                When returning a tuple, the first element is the sample tensor.
+                [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
+                `tuple`. When returning a tuple, the first element is the sample tensor.
         """
         # 1. time
         if not isinstance(timesteps, jnp.ndarray):
...